{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([7000, 500]) torch.Size([7000, 1]) torch.Size([3000, 500]) torch.Size([3000, 1])\n",
      "epoch 1,train_loss 0.017197,test_loss 0.016579\n",
      "epoch 2,train_loss 0.016889,test_loss 0.016249\n",
      "epoch 3,train_loss 0.016663,test_loss 0.016006\n",
      "epoch 4,train_loss 0.016495,test_loss 0.015823\n",
      "epoch 5,train_loss 0.016368,test_loss 0.015684\n",
      "epoch 6,train_loss 0.016271,test_loss 0.015577\n",
      "epoch 7,train_loss 0.016194,test_loss 0.015492\n",
      "epoch 8,train_loss 0.016132,test_loss 0.015424\n",
      "epoch 9,train_loss 0.016081,test_loss 0.015369\n",
      "epoch 10,train_loss 0.016037,test_loss 0.015322\n",
      "epoch 11,train_loss 0.015999,test_loss 0.015282\n",
      "epoch 12,train_loss 0.015964,test_loss 0.015246\n",
      "epoch 13,train_loss 0.015933,test_loss 0.015213\n",
      "epoch 14,train_loss 0.015903,test_loss 0.015184\n",
      "epoch 15,train_loss 0.015875,test_loss 0.015156\n",
      "epoch 16,train_loss 0.015847,test_loss 0.015130\n",
      "epoch 17,train_loss 0.015821,test_loss 0.015105\n",
      "epoch 18,train_loss 0.015795,test_loss 0.015080\n",
      "epoch 19,train_loss 0.015769,test_loss 0.015057\n",
      "epoch 20,train_loss 0.015744,test_loss 0.015033\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZUAAAEGCAYAAACtqQjWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXxU5dn/8c+VnSwkIQs7JOwoYFhkFUVRBFRw3xWsirbS6tOfttjWPpXaVh+tW0tVVHCtoiiKiqIsLqgsAcMeIIQIYQlJgJAEAlnu3x/nhAwhE5IwM2eSXO/X67wyOdtcM0nmm/ss9y3GGJRSSilPCHC6AKWUUk2HhopSSimP0VBRSinlMRoqSimlPEZDRSmllMcEOV2Ak+Lj401SUpLTZSilVKOyevXqPGNMQk3LmnWoJCUlkZqa6nQZSinVqIjIz+6W6eEvpZRSHqOhopRSymM0VJRSSnlMsz6nopRqekpLS8nOzqakpMTpUhq9sLAwOnToQHBwcJ230VBRSjUp2dnZREVFkZSUhIg4XU6jZYwhPz+f7OxskpOT67ydHv5SSjUpJSUlxMXFaaCcIREhLi6u3i0+DRWlVJOjgeIZDXkfNVQaYPPewzzxRTo6bIBSSp1MQ6UBVmTm88LX21m8eb/TpSillF/RUGmAW4Z2pmtCBH9fsJnjZRVOl6OU8iOHDh3iP//5T723Gz9+PIcOHar3dpMnT2bu3Ln13s5bNFQaIDgwgD9ddhaZecW8udxtbwVKqWbIXaiUl5fXut2CBQuIiYnxVlk+o5cUN9Congmc3yOB5xZt5ar+7WkVEeJ0SUqpah79ZCOb9hz26D7PateS/73ibLfLp02bxvbt20lJSSE4OJjIyEjatm1LWloamzZt4sorr2TXrl2UlJRw//33M2XKFKCqL8KioiLGjRvHeeedxw8//ED79u35+OOPadGixWlrW7x4MQ8++CBlZWWce+65vPDCC4SGhjJt2jTmz59PUFAQY8aM4amnnuL999/n0UcfJTAwkOjoaL799luPvD/aUmkgEeFPl/Wm+Hg5zy7a6nQ5Sik/8fjjj9O1a1fS0tJ48sknWblyJX/729/YtGkTALNmzWL16tWkpqby/PPPk5+ff8o+tm3bxn333cfGjRuJiYnhgw8+OO3zlpSUMHnyZObMmcP69espKyvjhRde4MCBA8ybN4+NGzeybt06/vSnPwEwffp0Fi5cyNq1a5k/f77HXr+2VM5Aj9ZR3Dy4E2+v2MltQzvTvXWU0yUppVzU1qLwlcGDB5908+Dzzz/PvHnzANi1axfbtm0jLi7upG2Sk5NJSUkBYODAgWRlZZ32ebZs2UJycjI9evQAYNKkScyYMYOpU6cSFhbGXXfdxWWXXcbll18OwIgRI5g8eTLXX389V199tSdeKqAtlTP2P5f0IDwkkMc+2+x0KUopPxQREXHi8ddff82iRYv48ccfWbt2Lf3796/x5sLQ0NATjwMDAykrKzvt87i7xSEoKIiVK1dyzTXX8NFHHzF27FgAXnzxRR577DF27dpFSkpKjS2mhtBQOUOtIkK4f3R3vtmay9IteomxUs1dVFQUhYWFNS4rKCggNjaW8PBw0tPTWb58uceet1evXmRlZZGRkQHAm2++yQUXXEBRUREFBQWMHz+eZ599lrS0NAC2b9/OkCFDmD59OvHx8ezatcsjdejhLw+4fVgSb6/Yyd8+28x53eIJDtSsVqq5iouLY8SIEfTp04cWLVrQunXrE8vGjh3Liy++SL9+/ejZsydDhw712POGhYUxe/ZsrrvuuhMn6u+9914OHDjAxIkTKSkpwRjDM888A8BDDz3Etm3bMMYwevRozjnnHI/UIc35rvBBgwYZT438+NWmHO5+I5VHJ5zNpOFJHtmnUqr+Nm/eTO/evZ0uo8mo6f0UkdXGmEE1ra//UnvIxb0TGdEtjmcWbeXQkeNOl6OUUo7QUPEQ6xLjszh8tJTnFm9zuhylVBNz3333kZKSctI0e/Zsp8s6hZ5T8aDebVtyw7mdePPHn7l1aGe6JkQ6XZJSqomYMWOG0yXUibZUPOz/jelBWHAgf9dLjJVSzZCGiofFR4by64u6sTh9P99ty3W6HKWU8imvhoqIjBWRLSKSISLTalgeKiJz7OUrRCTJnh8nIktFpEhE/u2yfpSIpLlMeSLybG37csLkEUl0ahXOY59upqxcezFWSjUfXgsVEQkEZgDjgLOAm0TkrGqr3QkcNMZ0A54BnrDnlwCPAA+6rmyMKTTGpFROwM/Ah6fZl8+FBgXyh/G92JJTyLurPHNDkVJKNQbebKkMBjKMMZnGmOPAu8DEautMBF63H88FRouIGGOKjTHLsMKlRiLSHUgEvqttX555KfV36dltGJLciqe/2krB0VKnylBK+VhDx1MBePbZZzly5Eit6yQlJZGXl9eg/fuCN0OlPeD6b3q2Pa/GdYwxZUABEEfd3ATMMVV3b9ZpXyIyRURSRSQ1N9d75zxEhEcuP4uDR44zY2mG155HKeVfvB0q/s6blxTX1Eqofvt+XdZx50bgtvruyxgzE5gJ1h31dXyuBunTPprrBnZg9vc7uHlwJ5LiI06/kVLKcz6fBvvWe3afbfrCuMfdLnYdT+WSSy4hMTGR9957j2PHjnHVVVfx6KOPUlxczPXXX092djbl5eU88sgj5OTksGfPHi688ELi4+NZunTpaUt5+umnmTVrFgB33XUXDzzwQI37vuGGG2ocU8UbvBkq2UBHl+87AHvcrJMtIkFANHDgdDsWkXOAIGPM6jPdl7c9OKYnn63by98XbGbm7TX2aqCUakIef/xxNmzYQFpaGl9++SVz585l5cqVGGOYMGEC3377Lbm5ubRr147PPvsMsDqajI6O5umnn2bp0qXEx8ef9nlWr17N7NmzWbFiBcYYhgwZwgUXXEBmZuYp+64cUyU9PR0RadCwxXXlzVBZBXQXkWRgN1bL4uZq68wHJgE/AtcCS0zdOiO7CXjHQ/vyqsSWYfzqwm48uXALP2zPY3jX0/+yKKU8pJYWhS98+eWXfPnll/Tv3x+AoqIitm3bxsiRI3nwwQf5/e9/z+WXX87IkSPrve9ly5Zx1VVXneha/+qrr+a7775j7Nixp+y7rKysxjFVvMFr51Ts8xpTgYXAZuA9Y8xGEZkuIhPs1V4F4kQkA/gtcOKyYxHJAp4GJotIdrUrx67n1FBxuy+n3XleMu1jWvDXTzdTXuF4zimlfMQYw8MPP0xaWhppaWlkZGRw55130qNHD1avXk3fvn15+OGHmT59eoP2XZOa9u1uTBWvMMY022ngwIHGVz5Zu9t0/v2n5p0VP/vsOZVqjjZt2uTo8+fl5ZlOnToZY4xZuHChGTx4sCksLDTGGJOdnW1ycnLM7t27zdGjR40xxsybN89MnDjRGGNMnz59TGZmZq3779y5s8nNzTWrV682ffv2NcXFxaaoqMicffbZZs2aNTXuu7Cw0OTk5BhjjMnPzzexsbF1fj01vZ9AqnHzuap9f/nIZX3b8lrnLJ76cguX9WtLVFiw0yUppbzAdTyVcePGcfPNNzNs2DAAIiMjeeutt8jIyOChhx4iICCA4OBgXnjhBQCmTJnCuHHjaNu27WlP1A8YMIDJkyczePBgwDpR379/fxYuXHjKvgsLC2scU8UbdDwVD42nUhdrdx1i4ozv+eWorvx+bC+fPa9SzYmOp+JZOp6KHzunYwxXD2jPq9/tYNeBxn0tulJK1URDxcd+d2kvAgOEvy/QXoyVUu4NGTLklPFT1q/38D03XqDnVHysTXQYvxzVlae/2srHabuZmFK9kwGl1JkyxuBgL00esWLFCqdLcHuFWW20peKAX47qyrlJsUz7YD1b9hU6XY5STUpYWBj5+fkN+kBUVYwx5OfnExYWVq/t9ES9D0/Uu9p/uITxzy+jZVgQH08doVeDKeUhpaWlZGdnU1Litj9aVUdhYWF06NCB4OCTP59qO1Gvh78cktgyjBk39+fmV1bw0PvreOHWAY2+ua6UPwgODiY5OdnpMpotPfzloCFd4pg2thdfbNzHy99lOl2OUkqdMQ0Vh901MpnxfdvwxBdbWJ6Z73Q5Sil1RjRUHCYiPHFNPzrHhTP1vz+Rc1iPAyulGi8NFT8QFRbMi7cOpPhYGfe9vYZSHddeKdVIaaj4iR6to3ji2n6k/nxQb4xUSjVaGip+ZMI57Zg8PInZ32cxf2318cyUUsr/aaj4mT+M783AzrFM+2Ad23L0xkilVOOioeJnQoICmHHzAMJDArnnrdUUHStzuiSllKozDRU/1CY6jH/dNICf84/wu7lrtbsJpVSjoaHip4Z1jeN3l/Zkwfp9vLpsh9PlKKVUnWio+LEp53dh7Nlt+Mfn6azQGyOVUo2AhoofExGevK4fnVuFM/Wdn9ivN0YqpfychoqfiwoL5sXbBlJUUsZ9/9UbI5VS/k1DpRHo0TqKx6/py6qsgzz+ebrT5SillFsaKo3ExJT2TB6exKvLdvDpOr0xUinlnzRUGpE/jO/NgE4x/G7uOjL2642RSin/o6HSiIQEBTDjlgG0CA7knjdXk190zOmSlFLqJBoqjUzb6BbMuGUA2QePcuPM5XpFmFLKr2ioNEJDu8Tx2h2D2X3oKNe/9CO7Dx11uiSllAI0VBqtYV3jeOuuIeQXH+f6F3/k5/xip0tSSikNlcZsQKdY3rl7KEeOl3Hdiz/qyXullOM0VBq5Pu2jmXPPMAxww0vL2binwOmSlFLNmIZKE9CjdRTv3TOM0KAAbpq5nJ92HnS6JKVUM6Wh0lCH9zpdwUmS4yN4795hxISHcOsrK7QDSqWUIzRUGuK7f8KMwVDsXx/cHWLDee+eYbSJDmPS7JV8uzXX6ZKUUs2MhkpD9LwMjhfBsqedruQUbaLDmHPPMJLjI7nr9VS+2pTjdElKqWZEQ6UhEnvBOTfDypfh0C6nqzlFfGQo7949lN7tWvLLt1bzyVrtK0wp5RsaKg01ahpg4JvHna6kRtHhwbx152AGdIrl/nd/Yu7qbKdLUko1AxoqDRXTEc69G9L+C7lbnK6mRlFhwbz+i8GM6BbPg++v5c0fs5wuSSnVxHk1VERkrIhsEZEMEZlWw/JQEZljL18hIkn2/DgRWSoiRSLy72rbhIjITBHZKiLpInKNPX+yiOSKSJo93eXN1wbAyN9CcAQs+avXn6qhWoQE8vLtg7i4dyKPfLyRmd9ud7okpVQT5rVQEZFAYAYwDjgLuElEzqq22p3AQWNMN+AZ4Al7fgnwCPBgDbv+I7DfGNPD3u83LsvmGGNS7OkVz70aNyLiYfivYfMnkL3a60/XUGHBgbxw60Au69eWvy9I57lF2zDGOF2WUqoJ8mZLZTCQYYzJNMYcB94FJlZbZyLwuv14LjBaRMQYU2yMWYYVLtX9AvgHgDGmwhiT553y62jYryA8Hhb9L/jxB3VwYADP39ifawZ04JlFW3n883QqKvy3XqVU4+TNUGkPuF4alW3Pq3EdY0wZUADEuduhiMTYD/8qImtE5H0Rae2yyjUisk5E5opIRzf7mCIiqSKSmpvrgfs4QqPg/Icg6zvIXHrm+/OiwADhyWv7cdvQzrz0bSaTZq/UrvOVUh7lzVCRGuZV/9e4Luu4CgI6AN8bYwYAPwJP2cs+AZKMMf2ARVS1gE7euTEzjTGDjDGDEhISaqu/7gbdATGdYNGjUFHhmX16SUCAMH3i2fzj6r6syjrAuOe+Y2n6fqfLUko1Ed4MlWzAtbXQAah+w8SJdUQkCIgGDtSyz3zgCDDP/v59YACAMSbfGFM5FOLLwMAzKb5egkLhwj/C3jTY9JHPnrahRISbBnfik6nnkRAVyh2vrWL6J5s4VlbudGlKqUbOm6GyCuguIskiEgLcCMyvts58YJL9+FpgianlDLK97BNglD1rNLAJQETauqw6Adh8pi+gXvpeB4lnwZLHoLzUp0/dUN1bR/HRfSOYPDyJWd/v4KoZP5Cxv8jpspRSjZjXQsU+RzIVWIj1Af+eMWajiEwXkQn2aq8CcSKSAfwWOHHZsYhkAU8Dk0Uk2+XKsd8DfxGRdcBtwP+z5/9GRDaKyFrgN8Bkb722GgUEwug/w4Ht8NNbPn3qMxEWHMhfJpzNK7cPYm/BUa741zLmrNqpV4cppRpEmvOHx6BBg0xqaqrndmgMzLoUDu2EX6+BkHDP7dsHcg6X8D9z0vhhe751+fFVfYluEex0WUopPyMiq40xg2papnfUe5IIXPwXKNwLK2c6XU29tW4Zxpt3DuF3Y3uycMM+xj/3Hat/ru0Ul1JKnUxDxdM6D4fuY6wejI82vsGyAgOEX43qxvv3DiMgAK5/aTnPL95Gud7TopSqAw0Vbxj9ZygpgO+fd7qSBuvfKZYFvxnJ5f3a8vRXW7n55eXsLTjqdFlKKT+noeINbfpaV4MtfwEK9zldTYNFhQXz7A0p/PO6c9iwu4Cxz37HFxsa7+tRSnmfhoq3XPgHqCiFb/7P6UrOiIhwzcAOfPqbkXRqFc69b63mj/PWU1Kq97QopU6loeItrbrAwDtgzeuQ3/h7Bk6Oj+CDXw7nnvO78PaKnVzxr2V8n+Fst2tKKf+joeJN5z8EgSGw9G9OV+IRIUEBPDy+N2/8YjBHjpdzyysr+MVrq9iWU+h0aUopP6Gh4k1RrWHor2DDB7B3rdPVeMz5PRJY/P8u4OFxvViVdYBLn/2WP8xbT27hsdNvrJRq0jRUvG3Eb6BFLCye7nQlHhUWHMg9F3Tlm4cu5PZhSby3ahejnlzKjKUZer5FqWZMQ8XbwqLhvN9CxiLY8Z3T1Xhcq4gQ/jLhbL78n/MZ0S2eJxdu4cKnvubDNdk6XotSzZCGii8Mvhui2sHiR/16IK8z0SUhkpm3D+LdKUOJjwzlt++tZcKMZfy4Pd/p0pRSPqSh4gvBLWDUNMheBVsWOF2NVw3tEsfH943g2RtSOFB0nJteXs5dr6/S3o+Vaia0Q0lPdihZm/Iy+M8QCAiCX/5g9WrcxJWUljPr+x38Z+l2jpaWc/PgTjxwcXfiIkOdLk0pdQa0Q0l/EBgEFz0Cuemw9l2nq/GJsOBAfjWqG18/NIqbB3fivyt3csGTX/Ofr/VkvlJNlbZUfNVSAet8yssXQnEeTE2F4DDfPbcfyNhfxOOfb2bR5v20iw7jF+clc/25HWkZpt3rK9WYaEvFX1R2jV+wC1JnOV2Nz3VLjOSVSefy37uH0CE2nMc+28ywvy/mL/M3kpVX7HR5SikP0JaKL1sqld6YCHvXwdRVEBHv++f3E+uzC5j9/Q4+WbeHsgrD6F6tufO8ZIZ2aYWIOF2eUsqN2loqGipOhErORpg5CrqOhpvesVowzdj+wyW8ufxn3l6xkwPFx+ndtiW/GJHEhJR2hAY1/QsalGpsNFTccCxUwOoW/4tpMP4p6z4WRUlpOR+n7WbWsiy25BQSHxnCrUM7c8uQziRE6RVjSvkLDRU3HA0VY+Dt62DHtzDla2h9ljN1+CFjDN9n5DPr+x0sSd9PSGAAE1PacceIZM5q19Lp8pRq9jRU3HA0VACKcuGF4dZ5lbuXWDdJqpNszy3ite+zmLs6m6Ol5QzrEsed5yVzUa9EAgKa92FDpZyioeKG46ECsG0RvH0NDJ4C4590thY/VnCklHdX7eT1H7LYU1BC57hwrurfnokp7UmOj3C6PKWaFQ0VN/wiVAC++AMsnwE3vQs9xzldjV8rK6/gi437eHv5TpbvyMcY6NchmgnntOOKc9rRumXzuvdHKSeccaiIyP3AbKAQeAXoD0wzxnzpyUJ9zW9CpewYvDIaCnZbXbi0bOt0RY3CvoISPl23h4/T9rB+dwEiMKxLHBNT2jG2T1uiW+hNlUp5gydCZa0x5hwRuRS4D3gEmG2MGeDZUn3Lb0IFIHcrvHQ+dBoCt86DAL0vtT625xYxP20PH6ftJiv/CCGBAYzqmcDElPaM7p1IWLBemqyUp3giVNYZY/qJyHPA18aYeSLykzGmv6eL9SW/ChWA1a/BJ/fDJdNhxP1OV9MoGWNYv7uAj37awyfr9pBbeIzI0CDGnN2aK1PaM7xrHEGBGthKnQlPhMpsoD2QDJwDBGKFy0BPFuprfhcqxsB7t1vd49/5FbRv1A1Bx5VXGJZn5vNx2m4+37CPwpIy4iNDuKxvWyaktKN/x1i9gkypBvBEqAQAKUCmMeaQiLQCOhhj1nm2VN/yu1ABOHIAXjwPgsLgnm8hNNLpipqEktJyvt6Sy/y1u1m0eT/HyyqIiwjhgp4JXNQrkZHdE/QcjFJ15IlQGQGkGWOKReRWYADwnDHmZ8+W6lt+GSoAWcvgtcsh5Ra4cobT1TQ5h0tKWbJ5P0vS9/PN1lwKjpYSFCAM7BzLRb0SuahXIt0SI7X/MaXc8Mg5FazDXv2AN4FXgauNMRd4slBf89tQAVjyGHz7JFw7C/pc43Q1TVZZeQU/7TrEkvT9LE3fT/q+QgA6xLbgol6JXNgrkWFd4vREv1IuPBEqa4wxA0Tkz8BuY8yrlfM8Xawv+XWolJfC7HHWVWG/XAYxnZyuqFnYfegoS9P38/WW/SzLyKOktIKw4ABGdI3nQjtk2sdozweqefNEqHwDfAH8AhgJ5GIdDuvryUJ9za9DBeDADnhxJLQ+GyZ/Zo0eqXympLSc5Zn5LE3fz5It+9l14CgAvdpEMapnIsO6xjGocywRofpzUc2LJ0KlDXAzsMoY852IdAJGGWPe8GypvuX3oQKw7n348C4Y9TCMmuZ0Nc2WMYbtuUUsSbfOxaRmHaSswhAYIPRtH82QLq0YmhzHoKRYonQkS9XEeaSbFhFpDZxrf7vSGLPfQ/U5plGECsCH98D69+COz6HTUKerUUDxsTJW/3yQFTvyWZF5gLXZhygtNwQI9GkfzZDkVgxJjuPc5FZ6VZlqcjzRUrkeeBL4GhCsQ2APGWPmerBOn2s0oVJyGF4aCRXlcO8yaBHjdEWqmqPHy1mz8yArMvNZvuMAaTsPcby8AhE4q21LhiTHMaRLK4YktyImPMTpcpU6Ix7ppgW4pLJ1IiIJwCJjzDkerdTHGk2oAGSnwqxLofcVcO3sZj9apL8rKS3np52HTrRk1uw8yLGyCsA6JzMkuRUDOscyoFMsHWJb6OXLqlGpLVTqeoYxoNrhrnzgtH1diMhY4DmsO/BfMcY8Xm15KPAGMNDe5w3GmCwRiQPmYh1ue80YM9VlmxDg38AooAL4ozHmA3f7quPr838dBsGFf4DF06HbJdD/FqcrUrUICw5kWNc4hnWNA+BYWTnrsguslkzmAd5Lzeb1H63bvOIjQ0jpGEv/TjH07xTDOR1i9OS/arTq+pv7hYgsBN6xv78BWFDbBiISCMwALgGygVUiMt8Ys8lltTuBg8aYbiJyI/CEve8SrE4r+9iTqz8C+40xPew7/VudZl9Nx4gHYPtSWPAQdBwC8d2crkjVUWhQIOcmteLcpFZMvci6P2ZLTiE/7TxkTbsOsmhzDgABAj1aR9G/UywDOsXQv1MsXeIjtEsZ1SjU50T9NcAIrHMq3xpj5p1m/WHAX4wxl9rfPwxgjPmHyzoL7XV+FJEgYB+QYOyiRGQyMKhaS2UX0MsYU1zt+WrdV00a1eGvSgW74cURENPZOnEfEu50RcpDDh05TtquQ6zZeYifdh4kbdchCkvKAGgZFkRKp1j6d7RaMykdY/TcjHKMJw5/YYz5APigHs/bHtjl8n02MMTdOsaYMhEpAOKAvJp2KCKVZ6j/KiKjgO3AVGNMTn331WhFt4eJ/4F3b4Y5t1gDewWFOl2V8oCY8BBG9UxkVM9EACoqDJl5RXbIWEHzryXbqLD/TerUKpw+7VvSp300fdpF07d9NLERGjTKWbWGiogUAjX9py+AMca0rG3zGuZV31dd1nEVBHQAvjfG/FZEfgs8BdxW132JyBRgCkCnTo30LvVe42HCv2D+VHj/Drj+dQjUy1abmoAAoVtiFN0So7h+UEcAio6VsS77EGm7DrFhdwEbdh9mwfp9J7ZpH9PCCpp20fTpYIVNQpT+06F8p9ZQMcZEncG+s4GOLt93APa4WSfbPmQVDRyoZZ/5wBGg8tDb+1jnUuq8L2PMTGAmWIe/6vF6/MuA26D0KHz+EHw4Ba55BQK0f6qmLjI0iOFd4xneNf7EvIIjpWzcU8D63QVs2HOYjbsLWLgx58TyNi3DTm7RdIgmMSpUrzhTXuHNS0xWAd1FJBnYDdyIdVe+q/nAJOBH4FpgSW3nQIwxRkQ+wbryawkwGqg88V+vfTUJQ6ZA2VH46s8Q3AIm/FtHjGyGosODGd4tnuHdqoKmsKSUTXsOs353ARvtr4vT91P5FxEfGUrvtlGc1bYlvdpG0bttS7omRBKsA5ipM1TnE/UN2rnIeOBZrEuKZxlj/iYi04FUY8x8EQnD6vW4P1ar4kZjTKa9bRbQEggBDgFjjDGbRKSzvU0MVh9kdxhjdta2L3ca5Yn6miz9B3zzOJx7F4x/Su9hUTUqPlbG5r2HrcNmew6zee9htuUUcbzcun8mONA63Na7bRS927Skd9uW9G4bRVykHj5TJ/NINy1NUZMJFWOs1soPz8OwqTDmMQ0WVSel5RXsyCtm897DbN5byOa9h0nfd5icw8dOrJMQFWoFTJsoO2ha0iUhQls1zZhHrv5SfkzEGte+9Cj8+G8IibBulFTqNIIDA+jROooeraOYmFI1P7/oGOn7Ck+ETfq+w8z+Pv9EqyYkMIAuCRF0bx1Fj8RIureOomebKDq1CidQ76dp1jRUmgoRGPd/1jmWb56wzrGc9z9OV6UaqbjIUEZ0C2WEy3ma0vIKMnOLSd93mE32obOfdh7kk7VV19+EBgXQNSGSHq0j6dEmih6JVmB1iG2hN282ExoqTUlAAFzxPJSWwKK/QHA4DLnH6apUExEcGEDPNlaLZGJK+xPzi4+VsW1/EVtzCtmWU8jWnCJW7DjAR2lVYdMiOJBuiZF2q8j62i0xkvYxGjZNjYZKUxMQCFe9CGUl8PnvrBbLgNudrko1YRGhQaR0tO7yd3W4pJRtOUVsyyc7p7wAABgpSURBVClkS04h23KK+G5bLh+syT6xTovgQOswWmIk3U5MUXSOC9dzNo2UnqhvCifqa1J2zLrrPmMxXP0y9LvO6YqUAqzuaLbmFJGx355yi9i+v4jdh46eWCc4UOgcVz1sIumaEElYsN6P5TS9+suNJh0qYJ24f/s6+PkH66773lc4XZFSbhUfK2N7rhU02+zA2b6/iKz84hNd04hAh9gWdE+MomtCBF0TIulqh00r7aLGZzRU3GjyoQJwrBDevBr2/AQ3vQPdL3G6IqXq5VhZOVl5R+ywKTzRwtmRV3xijBqAmPBgK2Qqw8YOnI6xLQjSQ2kepaHiRrMIFYCjh+CNCZC7BW5+D7pc4HRFSp2x8grDnkNH2Z5bxPbcYuvrfutxXlHVfTaVh9Jcw6ZLQgRdEiJ1qOcG0lBxo9mECkBxPrx2GRzaCbfNg07VO4xWqukoOFpK5ilhU8TP+Ucoq6j6zIuPDKFLfGXIRNAlPpLkhAg6tdILBWqjoeJGswoVgMIcmD0OinNh0nxo19/pipTyqdLyCnYdOML23GIyc4vIzC0mM8/6ml98/MR6QQFCp1bhJ1o0XeIjSI63HsdHhjT7zjg1VNxodqECUJBtBUvJYbj+DT0UppSt4EjpiYA58TW3mB35xRx3OXcTFRZ0UtC4Ts1lGGgNFTeaZagAHMyC/94AeVvh4kdh+K+1rzCl3KioMOw+dJTMvJNbN1l5R066DBqgdctQO2BcQichgo6x4YQENZ3DaRoqbjTbUAHrqrCPfgWb58PZV8NEu88wpVSdHT1ezs8HitmRW0xmXjE7XKYDLofTAgOEjrEtTgROckIEyXERJMWH0y668fUqoB1KqlOFRlmHv5Y9A0v+CrnpcMNbENfV6cqUajRahATSq01LerU5dRDcQ0eOnxQymXlW+CzPPMDR0vIT64UGBdA5LpykuKrDaEn218Y4mJq2VJprS8VVxmL44E4wFXD1K9BjjNMVKdVkGWPYd7iEHXnFZOUdISvfOneTlV/MzvwjJ3qCBggPCTwRNknxVvB0SYggKS6CVhHOXTCgh7/c0FBxcTAL5twK+zZY3eaPfFBHkVTKxyrvvdmRZ4VMZSsnK6+YXQePUu5yOXRUWJAVNnGVLZuq1k5MuHd7F9BQcUNDpZrjR+DTB2DdHOg53uqYMiza6aqUUliXQ2cfPEqWfSgtyyV49hw6ikveEBMeTOe4CJLjwk8cSqsMH0/c8Kmh4oaGSg2MgZUzYeEfIDYJbngbEns5XZVSqhbHysrZdeDIicNplS2drLwj7Ck4iuvHfKuIEJLiwrn3gq6MObtNg55PT9SruhOxxmBp3QfenwSvjIYr/wNnTXS6MqWUG6FBgXRLjKJbYtQpy0pKy9l54MiJw2iVoRPgpfMx2lLRlop7h/fAnNtgd6o1iuRFj1jjtSilmrXaWip6Jla517Id3LEABk62Lj1++1o4csDpqpRSfkxDRdUuKBSueM4apjhrGcy8APauc7oqpZSf0lBRdTNwEtzxOZSXwatjYO0cpytSSvkhDRVVdx0GwT3fQPsBMG8KvHOT1ZW+UkrZNFRU/UQmwu0fWx1RZn4N/x4M3z0NZcdPu6lSqunTUFH1FxgM5z0A962AbqNh8aPw4nmw4zunK1NKOUxDRTVcTCe48W24aQ6UHYXXL4cPp0DRfqcrU0o5RENFnbmeY+FXK6z+wjZ8CP8aBCtfhory02+rlGpSNFSUZ4SEw+hH4Jc/QNt+sOBBeOVi2POT05UppXxIQ0V5VkIPmPSJ1YV+QTbMvBA+exCOHnK6MqWUD2ioKM8TgX7XwdRVMPhuSH0V/n0urHsPmnG3QEo1BxoqyntaxMD4J+HuJRDdAT68G16/AnK3Ol2ZUspLNFSU97XrD3ctgsuehn3r4IXhsOhRPSSmVBOkoaJ8IyAQzr0TpqZCn2tg2dPwTB8rXIrznK5OKeUhGirKtyIT4eqX4N5l0P1iq/fjZ/rAFw9bXe0rpRo1DRXljDZ94brX4L6VcPZVsOIleO4c+OQBOJjldHVKqQbSUFHOSugBV70Av1kD/W+FtLfh+QEw7149oa9UI+TVUBGRsSKyRUQyRGRaDctDRWSOvXyFiCTZ8+NEZKmIFInIv6tt87W9zzR7SrTnTxaRXJf5d3nztSkPi02Cy5+B+9fCkHth08cwYzC8N0nHb1GqEfFaqIhIIDADGAecBdwkImdVW+1O4KAxphvwDPCEPb8EeAR40M3ubzHGpNiTa0dTc1zmv+KxF6N8p2U7GPt3eGA9jPwtbF8CL42E/94Au1Y5XZ1S6jS82VIZDGQYYzKNMceBd4GJ1daZCLxuP54LjBYRMcYUG2OWYYWLao4i4mH0n61wufBPsGsFvHoxvD7B6g1Zb6JUyi95M1TaA7tcvs+259W4jjGmDCgA4uqw79n2Ia5HRERc5l8jIutEZK6IdKxpQxGZIiKpIpKam5tb5xejHNIiBi54CB7YAGMeg9x0qzfkV8dA2jtw/IjTFSqlXHgzVKSGedX/vazLOtXdYozpC4y0p9vs+Z8AScaYfsAiqlpAJ+/cmJnGmEHGmEEJCQmneSrlN0IjYfiv4f51MP4pOJIHH90L/+xpXTGWvVpbL0r5AW+GSjbg2lroAFS/EeHEOiISBEQDB2rbqTFmt/21EPgv1mE2jDH5xphj9movAwPPsH7lj4LDrP7Efr0GJi+AXpfB2nfhlYusO/V/nAHF+U5XqVSz5c1QWQV0F5FkEQkBbgTmV1tnPjDJfnwtsMQY9/9uikiQiMTbj4OBy4EN9vdtXVadAGz2yKtQ/kkEkkbAVS/Cg1vh8mchOBwW/sFqvbx3O2z7Ssd0UcrHgry1Y2NMmYhMBRYCgcAsY8xGEZkOpBpj5gOvAm+KSAZWC+XGyu1FJAtoCYSIyJXAGOBnYKEdKIFYh7letjf5jYhMAMrsfU321mtTfiasJQy6w5pyNln3uqx9x7osOaodpNxs3QPTKtnpSpVq8qSWhkGTN2jQIJOamup0Gcobyo7D1s/hp7cgYxGYCkgaCf1vg7MmQHALpytUqtESkdXGmEE1LtNQ0VBp8gp2Wy2Xn96CgzsgNBr6XG11D9N5BAR6rcGuVJOkoeKGhkozU1EBO3+ANW9ah8bKjkKLWOg5HnpdDl0v1BaMUnWgoeKGhkozdvwIbF8Mmz+1DpOVFEBwhNVzcq8roMcYCIt2ukql/FJtoaLtftU8hYRD7yusqbwUsr6zAib9U6sVExAMXS6wlvccb3XZr5Q6LW2paEtFuaqogN2psHm+FTIHdwACnYZB78utw2SxnZ2uUilH6eEvNzRUVK2MgZyNVutl86eQs96a36af1YLpehG0TdET/arZ0VBxQ0NF1cuBHXbAfAK7VgLGupIseSQkXwBdRkF8d+vGTKWaMA0VNzRUVIMV58GObyHza2s69LM1P6qddS6myygraFq2db8PpRopPVGvlKdFxFv3uvS52vr+wA7Y8Y0VMNu+tO6LAYjvaQVMlwsg6Ty9okw1edpS0ZaK8rSKCsjZYAXMjm/g5x+g9AhIALQbUBUy7QdZV6Ep1cjo4S83NFSUT5Qdg+xVkGm3ZHavBlMOAUHQ9hzoOBQ62ZNeuqwaAQ0VNzRUlCNKDlutl13LYecKK2TK7VEbYpOrAqbjUIjvAQHe7ExcqfrTcypK+ZOwltBzrDWB1ZLZuxZ2LreGTd72VdU5mbAY6DikKmjaDbDGlFHKT2moKOW0oFDoONiawLo/Jn+73ZKpDJqF1rKAYGiXYgVN+4HQfgDEdNbLmJXf0FBRyt+IQHw3a+p/qzWvON8Kl8pDZitnQvlxa1mLWGjX32rFtOtvTS3badAoR2ioKNUYRMRBr/HWBNZ4Mfs3wZ41sOcn2P0TLHvGugAAILJ1VcBUhk1kgnP1q2ZDQ0WpxigoxDoM1i6lal7pUdi3wSVo1sDWhYB9MU50R3sbO2zanGOFlVIepKGiVFMR3AI6nmtNlY4Vwt51VUGz5yerm5lKUe2gTR9o09eaWveFVl30ijPVYBoqSjVloVGQNMKaKh09CHvSrBs09623WjcZi6sOnQVHQOuzXcKmHySepTdqqjrRUFGquWkRa41y2fXCqnmlJZCb7hI062H9B5A6y15BIK6bHTJ9rBZNYm+I7qAXBKiTaKgopax7X6qfozEGDu20AqYybHanwsYPq9YJiYKEnpDYCxJ6V33Vq8+aLQ0VpVTNRKwByWI7WwOUVTp6yBpnJncz5G6B/ZutCwJ+eqtqndDok8MmoafVsolqq2HTxGmoKKXqp0XMqedpwLqXJnezFTK56bA/HdI/gzVvVK0TFg0Jveypp9WLc3x368o0vTigSdBQUUp5RkQcRJxndfHvqijXDpt0K2xy060r0Na8XrVOUAv7hs+eVn9nCT2sr626arc0jYyGilLKuyITrCn5/JPnF+dD3hbI2wq5W62v2SthwwecuLdGAqxuaFyDprJ1E97K5y9FnZ6GilLKGRFxEDEcOg8/ef7xI5CfYYVM3lbrvE3eNmvYgMrenAHC46wr0uK6Q1xX63F8d6unZ23dOEZDRSnlX0LCoW0/a3JVUW4N25y3zQqa/Ayr482MryDN5SIBBGI62mFjB01l6LTsoOduvExDRSnVOAQEWnf7t+oCPS49eVnJYTiwHfIy7LDJgPxtViecx4uq1gsKs87TxNlTK5evkYl6ZZoHaKgopRq/sJZVfZq5MgaKcqzWzYmw2W51xrllAVSUVa0bEmkFVlxXO7xcAiciXgOnjjRUlFJNlwhEtbGm5JEnLysvg4KdkJ9ptXIOZFqBs3ctbJpf1W0NQGhLl8DpWhU8sckaONVoqCilmqfAoKrDaVx88rLyUqs3gfztVuDk26GzezVsnAemomrdkCholVQVMq26QCv7a1S7ZncOR0NFKaWqCwyuOu9SXdlx64KBAzusoDmQCQd3WL0MpC+AilKX/YRCbFJVyLiGTnRHawiDJkZDRSml6iMoxLqiLL77qcsqyqEguypoDmTa4bMDdnwLpUeq1pUAq0PO2CR7Sq4KoNgkq+PPRkhDRSmlPCUgsKq/NC48eVnlRQOVQXNwBxzMsqb0BXAk7+T1w6JPDZrK8GnZ3jp854f8syqllGpqXC8aqH7DJ1gDqlWGzMEsO3iyrN6h0z87+bBaQJB1+OyUwLGnsGivvxx3NFSUUsofhEZVjcBZXUU5HN5dFTQHd8DBn63HGz+CowdOXr9F7KlBUzm17ODVVo6GilJK+buAQIjpZE1ccOrykoKqkHGd9q6DzZ/W0MrpABc9An2v9XipXg0VERkLPAcEAq8YYx6vtjwUeAMYCOQDNxhjskQkDpgLnAu8ZoyZ6rLN10Bb4Kg9a4wxZr+7fXnx5SmllH8Ii665axuwWzl7Tg2ciHivlOK1UBGRQGAGcAmQDawSkfnGmE0uq90JHDTGdBORG4EngBuAEuARoI89VXeLMSa12jx3+1JKqeYrINDqCy2m46k3gHrj6by478FAhjEm0xhzHHgXmFhtnYlA5aAKc4HRIiLGmGJjzDKscKmrGvfV8PKVUkrVlzdDpT2wy+X7bHtejesYY8qAAiCuDvueLSJpIvKIS3DUaV8iMkVEUkUkNTc3tz6vRyml1Gl4M1RqaiWYBqxT3S3GmL7ASHu6rT77MsbMNMYMMsYMSkhIOM1TKaWUqg9vhko20NHl+w7AHnfriEgQEA1UuzbuZMaY3fbXQuC/WIfZGrQvpZRSnuXNUFkFdBeRZBEJAW4E5ldbZz4wyX58LbDEGOO2pSIiQSISbz8OBi4HNjRkX0oppTzPa1d/GWPKRGQqsBDrkuJZxpiNIjIdSDXGzAdeBd4UkQysVsWNlduLSBbQEggRkSuBMcDPwEI7UAKBRcDL9iZu96WUUso3pDn/Mz9o0CCTmlr9ymSllFK1EZHVxphBNS1rXh39K6WU8qpm3VIRkVysQ2oNEQ/knXYt39O66kfrqj9/rU3rqp8zqauzMabGy2ebdaicCRFJddf8c5LWVT9aV/35a21aV/14qy49/KWUUspjNFSUUkp5jIZKw810ugA3tK760brqz19r07rqxyt16TkVpZRSHqMtFaWUUh6joaKUUspjNFROQ0TGisgWEckQkWk1LA8VkTn28hUikuSDmjqKyFIR2SwiG0Xk/hrWGSUiBfYQAWki8mdv12U/b5aIrLef85TuCsTyvP1+rRORAT6oqafL+5AmIodF5IFq6/js/RKRWSKyX0Q2uMxrJSJficg2+2usm20n2etsE5FJNa3jwZqeFJF0++c0T0Ri3Gxb68/cS7X9RUR2u/y8xrvZtta/Xy/UNcelpiwRSXOzrVfeM3efDT79/TLG6ORmwupfbDvQBQgB1gJnVVvnV8CL9uMbgTk+qKstMMB+HAVsraGuUcCnDrxnWUB8LcvHA59jDVUwFFjhwM90H9bNW468X8D5wABgg8u8/wOm2Y+nAU/UsF0rINP+Gms/jvViTWOAIPvxEzXVVJefuZdq+wvwYB1+1rX+/Xq6rmrL/wn82ZfvmbvPBl/+fmlLpXYNHr3Sm0UZY/YaY9bYjwuBzZw6AJq/mgi8YSzLgRgRaevD5x8NbDfGNLQnhTNmjPmWU4dlcP09eh24soZNLwW+MsYcMMYcBL4CxnqrJmPMl8Ya8A5gOdbwFT7n5v2qi7r8/XqlLvsz4HrgHU89Xx1rcvfZ4LPfLw2V2nlz9EqPsA+39QdW1LB4mIisFZHPReRsH5VkgC9FZLWITKlheV3eU2+6Efd/6E68X5VaG2P2gvXBACTWsI6T790vsFqYNTndz9xbptqH5ma5OZzj5Ps1Esgxxmxzs9zr71m1zwaf/X5pqNTOW6NXeoSIRAIfAA8YYw5XW7wG6xDPOcC/gI98URMwwhgzABgH3Cci51db7uT7FQJMAN6vYbFT71d9OPLeicgfgTLgbTernO5n7g0vAF2BFGAv1qGm6hz7XQNuovZWilffs9N8NrjdrIZ59X6/NFRq55XRKz1BrDFlPgDeNsZ8WH25MeawMabIfrwACBZ7gDNvMsbssb/uB+ZRNTJnpbq8p94yDlhjjMmpvsCp98tFTuVhQPvr/hrW8fl7Z5+svRxrGO8aP2Dq8DP3OGNMjjGm3BhTgTWmUk3P6cjvmv05cDUwx9063nzP3Hw2+Oz3S0Oldh4fvdIT7OO1rwKbjTFPu1mnTeW5HREZjPWzzvdyXREiElX5GOtE74Zqq80HbhfLUKCgslnuA27/e3Ti/arG9fdoEvBxDessBMaISKx9uGeMPc8rRGQs8HtggjHmiJt16vIz90ZtrufhrnLznHX5+/WGi4F0Y0x2TQu9+Z7V8tngu98vT1990NQmrKuVtmJdRfJHe950rD80gDCswykZwEqgiw9qOg+rWboOSLOn8cC9wL32OlOBjVhXvCwHhvugri728621n7vy/XKtS4AZ9vu5Hhjko59jOFZIRLvMc+T9wgq2vUAp1n+Hd2Kdh1sMbLO/trLXHQS84rLtL+zftQzgDi/XlIF1jL3yd6zyKsd2wILafuY+eL/etH9/1mF9YLatXpv9/Sl/v96sy57/WuXvlcu6PnnPavls8Nnvl3bTopRSymP08JdSSimP0VBRSinlMRoqSimlPEZDRSmllMdoqCillPIYDRWlGimxelb+1Ok6lHKloaKUUspjNFSU8jIRuVVEVtpjZ7wkIoEiUiQi/xSRNSKyWEQS7HVTRGS5VI1hEmvP7yYii+wOL9eISFd795EiMlescU/e9nYP2UqdjoaKUl4kIr2BG7A6EEwByoFbgAisfsgGAN8A/2tv8gbwe2NMP6w7xivnvw3MMFaHl8Ox7uQGqxfaB7DGzOgCjPD6i1KqFkFOF6BUEzcaGAisshsRLbA686ugqsPBt4APRSQaiDHGfGPPfx143+4nqr0xZh6AMaYEwN7fSmP3MSXWKINJwDLvvyylaqahopR3CfC6Mebhk2aKPFJtvdr6S6rtkNYxl8fl6N+0cpge/lLKuxYD14pIIpwYK7wz1t/etfY6NwPLjDEFwEERGWnPvw34xljjYWSLyJX2PkJFJNynr0KpOtL/apTyImPMJhH5E9YofwFYPdreBxQDZ4vIaqzRQm+wN5kEvGiHRiZwhz3/NuAlEZlu7+M6H74MpepMeylWygEiUmSMiXS6DqU8TQ9/KaWU8hhtqSillPIYbakopZTyGA0VpZRSHqOhopRSymM0VJRSSnmMhopSSimP+f/R+5KqnmykZwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "from IPython import display\n",
    "import torch.utils.data as Data\n",
    "\n",
    "num_inputs = 500\n",
    "num_examples = 10000\n",
    "true_w = torch.ones(500,1)*0.0056\n",
    "true_b = 0.028\n",
    "#随机生成的数据样本\n",
    "features = torch.tensor(np.random.normal(0, 1, (num_examples, num_inputs)), dtype=torch.float)#行*列=10000*500\n",
    "labels = torch.mm(features,true_w) + true_b\n",
    "labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float) #扰动项\n",
    "#训练集和测试集上的样本&标签数----真实的特征和样本\n",
    "trainfeatures = features[:7000]\n",
    "trainlabels = labels[:7000]\n",
    "testfeatures = features[7000:]  \n",
    "testlabels = labels[7000:]\n",
    "print(trainfeatures.shape,trainlabels.shape,testfeatures.shape,testlabels.shape)\n",
    "\n",
    "#获得数据迭代器\n",
    "batch_size = 50 # 设置小批量大小\n",
    "def load_array(data_arrays, batch_size, is_train=True):  #自定义函数\n",
    "    #\"\"\"构造一个PyTorch数据迭代器。\"\"\"\n",
    "    dataset = Data.TensorDataset(*data_arrays)#features 和 labels作为list传入，得到PyTorch的一个数据集\n",
    "    return Data.DataLoader(dataset, batch_size, shuffle=is_train,num_workers=0)#返回的是实例化后的DataLoader\n",
    "train_iter = load_array([trainfeatures,trainlabels],batch_size)\n",
    "test_iter = load_array([testfeatures,testlabels],batch_size)\n",
    "\n",
    "#定义超参数\n",
    "num_inputs=500\n",
    "num_hiddens = 256\n",
    "num_outputs = 1\n",
    "#定义参数\n",
    "W1 = torch.tensor(np.random.normal(0, 0.01, (num_inputs,num_hiddens)), dtype=torch.float32)  \n",
    "b1 = torch.zeros(1, dtype=torch.float32)  \n",
    "W2 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_outputs)), dtype=torch.float32)  \n",
    "b2 = torch.zeros(1, dtype=torch.float32)  \n",
    "params = [W1,b1,W2,b2]\n",
    "for param in params:\n",
    "    param.requires_grad_(requires_grad = True)#设置为true，追踪并记录所有在计算图上的操作（正向积累）\n",
    "\n",
    "def relu(x):  \n",
    "    x = torch.max(input=x,other=torch.tensor(0.0))  \n",
    "    return x \n",
    "\n",
    "#定义模型  \n",
    "def net(X):  \n",
    "    X = X.view((-1,num_inputs)) #将数据进行展平，对于空间结构的数据生效\n",
    "    H = relu(torch.matmul(X,W1)+b1)  \n",
    "    return torch.matmul(H,W2)+b2  \n",
    "loss = torch.nn.MSELoss()\n",
    "def SGD (params,lr,batch_size):\n",
    "    for param in params:\n",
    "        param.data -= lr * param.grad/batch_size\n",
    "#记录列表（list），存储训练集和测试集上经过每一轮次，loss的变化\n",
    "def train (net,train_iter,test_iter,loss,num_epochs,batch_size,params = None,lr=None,optimizer=None):\n",
    "    train_loss=[]\n",
    "    test_loss=[]\n",
    "    for epoch in range(num_epochs):#外循环控制循环轮次\n",
    "        #step1在训练集上，进行小批量梯度下降更新参数\n",
    "        for X,y in train_iter:#内循环控制训练批次\n",
    "            y_hat = net(X)\n",
    "            l = loss(y_hat,y)#l.size = torch.Size([]),即说明loss为表示*标量*的tensor`\n",
    "            #梯度清零\n",
    "            if optimizer is not None:\n",
    "                optimizer.zero_grad()\n",
    "            elif params is not None and params[0].grad is not None:\n",
    "                for param in params:\n",
    "                    param.grad.data.zero_()\n",
    "            l.backward()\n",
    "            if optimizer is None:\n",
    "                SGD(params,lr,batch_size)\n",
    "            else:\n",
    "                optimizer.step()\n",
    "        #step2 每经过一个轮次的训练， 记录训练集和测试集上的loss\n",
    "        train_labels = trainlabels.view(-1,1)  \n",
    "        test_labels = testlabels.view(-1,1) \n",
    "        train_loss.append((loss(net(trainfeatures),train_labels)).item())#！注意要取平均值\n",
    "        test_loss.append((loss(net(testfeatures),test_labels)).item())\n",
    "        print(\"epoch %d,train_loss %.6f,test_loss %.6f\"%(epoch+1,train_loss[epoch],test_loss[epoch])) \n",
    "    return train_loss, test_loss\n",
    "\n",
    "lr=0.01\n",
    "num_epochs = 20\n",
    "#batch_size、params epc已经定义\n",
    "train_loss, test_loss = train (net,train_iter,test_iter,loss,num_epochs,batch_size,params,lr)#每一给optimizer,默认None\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "x=np.linspace(0,len(train_loss),len(train_loss))\n",
    "plt.plot(x,train_loss,label=\"train_loss\",linewidth=1.5)\n",
    "plt.plot(x,test_loss,label=\"test_loss\",linewidth=1.5)\n",
    "plt.xlabel(\"epoch\")\n",
    "plt.ylabel(\"loss\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([14000, 200]) torch.Size([14000, 1]) torch.Size([6000, 200]) torch.Size([6000, 1])\n",
      "epoch 1,train_loss 32.818668,test_loss 33.047234\n",
      "epoch 2,train_loss 29.763063,test_loss 30.041527\n",
      "epoch 3,train_loss 26.843542,test_loss 27.103823\n",
      "epoch 4,train_loss 24.333836,test_loss 24.498997\n",
      "epoch 5,train_loss 22.103399,test_loss 22.288319\n",
      "epoch 6,train_loss 20.242859,test_loss 20.459768\n",
      "epoch 7,train_loss 18.379824,test_loss 18.508512\n",
      "epoch 8,train_loss 16.787054,test_loss 16.945805\n",
      "epoch 9,train_loss 15.316413,test_loss 15.494952\n",
      "epoch 10,train_loss 13.986809,test_loss 14.163454\n",
      "epoch 11,train_loss 12.796457,test_loss 12.986771\n",
      "epoch 12,train_loss 11.763365,test_loss 11.954306\n",
      "epoch 13,train_loss 10.748314,test_loss 10.919194\n",
      "epoch 14,train_loss 9.865685,test_loss 10.024225\n",
      "epoch 15,train_loss 9.078724,test_loss 9.221390\n",
      "epoch 16,train_loss 8.355811,test_loss 8.472768\n",
      "epoch 17,train_loss 7.732056,test_loss 7.825146\n",
      "epoch 18,train_loss 7.099344,test_loss 7.169745\n",
      "epoch 19,train_loss 6.561232,test_loss 6.615660\n",
      "epoch 20,train_loss 6.060052,test_loss 6.113162\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEGCAYAAABiq/5QAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3dd3wUdf7H8ddnNxUIoXcp0kRaAqFLR5pIL4ogIEgRFFQ88XenIKd3nnq2OwT1Dk8RFQVBTlFBBBGkGCD0EkINhBZII6Tu9/fHLl7EhCSQ3Umyn+fjkQeb2ZmdN5PsO7Mzs98VYwxKKaW8h83qAEoppTxLi18ppbyMFr9SSnkZLX6llPIyWvxKKeVlfKwOkBcVKlQwtWvXtjqGUkoVKdu3b79ojKl4/fQiUfy1a9cmPDzc6hhKKVWkiMiJ7KbroR6llPIyWvxKKeVltPiVUsrLFIlj/Eqp4ic9PZ3o6GhSUlKsjlLkBQQEUKNGDXx9ffM0vxa/UsoS0dHRBAUFUbt2bUTE6jhFljGG2NhYoqOjqVOnTp6W0UM9SilLpKSkUL58eS39WyQilC9fPl+vnLT4lVKW0dIvGPndjsW7+E9tg42vW51CKaUKleJd/HuWwvdzIHKN1UmUUqrQKNbFP9/vQU741MYsnwyJZ62Oo5QqROLi4nj77bfzvVzfvn2Ji4vL93Jjx45l6dKl+V7OHYp18VevWI6HrkwlIyURlk8Ch8PqSEqpQiKn4s/MzLzhcqtWraJMmTLuiuURxfpyzv7Nq/HjoTY8u+tBXjr6Hvz8Jtz1uNWxlFLXef6/+9h/JqFAH/POaqWZfW/jHO+fNWsWUVFRhISE4OvrS6lSpahatSoRERHs37+fgQMHcurUKVJSUpg+fToTJ04E/jd2WFJSEn369OGuu+7i559/pnr16nz55ZcEBgbmmm3t2rXMnDmTjIwMWrVqxfz58/H392fWrFmsXLkSHx8fevbsyauvvsrnn3/O888/j91uJzg4mA0bNtzytinWe/wAzw9ozObgvqy1tcf88AJE62BvSil46aWXqFu3LhEREbzyyits27aNF198kf379wOwcOFCtm/fTnh4OG+99RaxsbG/e4zIyEimTp3Kvn37KFOmDMuWLct1vSkpKYwdO5YlS5awZ88eMjIymD9/PpcuXWL58uXs27eP3bt386c//QmAuXPn8t1337Fr1y5WrlxZIP/3Yr3HD1DK34c37gtl3IKHWFsiinJLH0Im/wQBwVZHU0q53GjP3FNat279mzdAvfXWWyxfvhyAU6dOERkZSfny5X+zTJ06dQgJCQGgZcuWHD9+PNf1HDp0iDp16tCgQQMAxowZw7x585g2bRoBAQFMmDCBe+65h379+gHQoUMHxo4dy/Dhwxk8eHBB/FeL/x4/QGjNskzoEcLDV6Zg4qPhqyfAGKtjKaUKkZIlS/56e/369Xz//fds3ryZXbt2ERoamu0bpPz9/X+9bbfbycjIyHU9Jofu8fHxYdu2bQwZMoQVK1bQu3dvABYsWMALL7zAqVOnCAkJyfaVR355RfEDTOlSD5/abflH5lDYuxQiPrY6klLKQkFBQSQmJmZ7X3x8PGXLlqVEiRIcPHiQLVu2FNh677jjDo4fP86RI0cAWLRoEZ07dyYpKYn4+Hj69u3LG2+8QUREBABRUVG0adOGuXPnUqFCBU6dOnXLGYr9oZ5r7DbhjREh9H1jEN18D9Bk1VPIba2hQn2roymlLFC+fHk6dOhAkyZNCAwMpHLlyr/e17t3bxYsWECzZs1o2LAhbdu2LbD1BgQE8P777zNs2LBfT+5OnjyZS5cuMWDAAFJSUjDG8PrrzjefPvXUU0RGRmKMoXv37jRv3vyWM0hOLzsKk7CwMFNQn8D19e4Y5n78PetL/R+BFWrDhO/Bxz/X5ZRSBevAgQM0atTI6hjFRnbbU0S2G2PCrp/Xaw71XHNPs6p0atmMx64+DGd3O9/Zq5RSXsTrih9gTv/GHCnbic/sfWHL23D4O6sjKaWKialTpxISEvKbr/fff9/qWL/hNcf4syrp78MbI0K4f/4I2pc6SPUVU5ApP0NQFaujKaWKuHnz5lkdIVdeuccP0Py2Mkzr2YQxiVPITE2GLybqkA5KKa/gtcUPMKlTXSrWacqc9Afh2I+w6Q2rIymllNt5dfHbbcLrI0L4r607P/l1dA7pcOoXq2MppZRbeXXxA1QNDuRvQ5sxNeFBEnwrwbLxkBJvdSyllHIbry9+gN5NqtK31R2MS5rkGtLhcR3SQali7mbH4wd44403SE5OvuE8tWvX5uLFizf1+O6mxe/y3L13ElculAW2EbB3GUQstjqSUsqN3F38hZnbLucUkQBgA+DvWs9SY8xsEakDfAqUA3YAo40xae7KkVcl/Hx4875Qhs5PomvQfhquegq5rY0O6aCUJ3wzC87uKdjHrNIU+ryU491Zx+O/++67qVSpEp999hmpqakMGjSI559/nitXrjB8+HCio6PJzMzk2Wef5dy5c5w5c4auXbtSoUIF1q1bl2uU1157jYULFwIwYcIEZsyYke1jjxgxItsx+QuaO6/jTwW6GWOSRMQX2Cgi3wBPAK8bYz4VkQXAeGC+G3PkWdMawTzRsxFjvpnAhqA/4b90HExYq0M6KFUMvfTSS+zdu5eIiAhWr17N0qVL2bZtG8YY+vfvz4YNG7hw4QLVqlXj66+/BpyDtwUHB/Paa6+xbt06KlSokOt6tm/fzvvvv8/WrVsxxtCmTRs6d+7M0aNHf/fY18bkP3jwICJyUx/xmBduK37jHAQoyfWtr+vLAN2Aka7pHwBzKCTFD/Bwx9vZEHmBGScmMv/sy7Bm9g33GpRSBcDi59jq1atZvXo1oaGhACQlJREZGUnHjh2ZOXMmTz/9NP369aNjx475fuyNGzcyaNCgX4d9Hjx4MD/99BO9e/f+3WNnZGRkOyZ/QXPrMX4RsYtIBHAeWANEAXHGmGuDVkcD1d2ZIb9sNuHvw0LY7NuKlQH3wtb5cHCV1bGUUm5kjOGZZ54hIiKCiIgIjhw5wvjx42nQoAHbt2+nadOmPPPMM8ydO/emHjs72T12TmPyFzS3Fr8xJtMYEwLUAFoD2Q3Fl+1WEZGJIhIuIuEXLlxwZ8zfqRIcwN+GNGNm3FBiSjSEFVMg7qRHMyil3CvrePy9evVi4cKFJCU5D1KcPn2a8+fPc+bMGUqUKMGoUaOYOXMmO3bs+N2yuenUqRMrVqwgOTmZK1eusHz5cjp27JjtY+c0Jn9B88hYPcaYOBFZD7QFyoiIj2uvvwZwJodl3gXeBeewzJ7ImVWvxlUY2qYu922bxNqSz+KzdDyMWwV2X09HUUq5Qdbx+Pv06cPIkSNp164dAKVKleKjjz7iyJEjPPXUU9hsNnx9fZk/33lUeuLEifTp04eqVavmenK3RYsWjB07ltatWwPOk7uhoaF89913v3vsxMTEbMfkL2huG49fRCoC6a7SDwRWA38DxgDLspzc3W2MueE1VQU5Hn9+pKRnMnDeJkLif+Al8zp0mA535/+lnlLq93Q8/oJVWMbjrwqsE5HdwC/AGmPMV8DTwBMicgQoD/zbjRluSYCvnXkPtGBlZltWB/aFTW9C5BqrYyml1C1x51U9u4HQbKYfxXm8v0ioW7EUfxnUlEeXDOencoeptHwSTN4IpatZHU0pVQi0adOG1NTU30xbtGgRTZs2tShR7rxyPP78GhhanS1H63Jf+GTWlHwO+9LxMOa/YNfNp9StMMYgIlbHuCVbt261OkKOVw7lRIdsyKM5/RvjV7khszMnwMmf4Ue9tl+pWxEQEEBsbGy+S0v9ljGG2NhYAgIC8ryM7rLmUYCvnX+ObEH/fybTveRBumx4FanVAep2tTqaUkVSjRo1iI6OxtOXaxdHAQEB1KhRI8/za/HnQ71KpXhxUBMeWXIfP5WNpMIXD8PkTRBU2epoShU5vr6+1KlTx+oYXkkP9eTToNAa9A+rz8j4KWSmJMIXE8CRaXUspZTKMy3+mzCnf2OkYiNecIyDYxtgQ8GPnqeUUu6ixX8TAv3szHsglE8zOrMhsBvmx5fg2E9Wx1JKqTzR4r9J9SoF8cLApky5/ACX/WvAsgmQpCeplFKFnxb/LRjSsgb3hNVnVMIUHMmXYPkkcDisjqWUUjekxX+Lnu/fhIyKjXnJjIGotbDpDasjKaXUDWnx36JAPztvP9CCRRnd+TmgE+aHF+DEZqtjKaVUjrT4C8C14/2T4h4k3q8KLBsPyZesjqWUUtnS4i8gQ1rWoHfLBoxOnIIj6bzzw1v0rehKqUJIi78AzR3QhNSKzXjVjIbD38LmeVZHUkqp39HiL0CBfnbmjWzB+xk92ebfHvP9bIj2/AfIKKXUjWjxF7D6lYP488CmTIgfS4JvRfh8rF7fr5QqVLT43WBoyxr0bHkHoxOnkpl0HpaMgozU3BdUSikP0OJ3k7kDGnO1QjP+6JgKp7bAysf0ZK9SqlDQ4neTEn4+LBjdkm9ox/t+98PuT2Hja1bHUkopLX53qluxFO89GMZfr/RnY0AXWDsX9q+0OpZSystp8btZ6zrleHV4COPjxnIsoBFm+SQ4E2F1LKWUF9Pi94D+zasxvXdThsc9SqIEwSf3QUKM1bGUUl5Ki99DpnSuS4/WzRie+DjpyXHO8k9LtjqWUsoLafF7iIjw5wGNqdqgJY+kTMXE7IIVk3UYZ6WUx2nxe5CP3cY/R7bgTOUuvOp4APZ/Cev/YnUspZSX0eL3sJL+Piwc24rlAYNYKd1gwyuw+zOrYymlvIgWvwUqlw7g/YfaMNsxgV32Jpgvp8GpbVbHUkp5CS1+izSsEsQ/R7dh/NXHOEc5zKcjIe6k1bGUUl7AbcUvIreJyDoROSAi+0Rkumv6HBE5LSIRrq++7spQ2HWoV4FZQzrwQPITpFy9ivl4BKQmWh1LKVXMuXOPPwN40hjTCGgLTBWRO133vW6MCXF9rXJjhkJvaMsa3Nu9Cw+nPIo5fwiWTQBHptWxlFLFmNuK3xgTY4zZ4bqdCBwAqrtrfUXZ9O71qRLah+fSH3R+gMua56yOpJQqxjxyjF9EagOhwFbXpGkisltEFopI2RyWmSgi4SISfuFC8R7PXkT46+CmHK9zPx9m9oTN/4QdH1odSylVTLm9+EWkFLAMmGGMSQDmA3WBECAG+Ht2yxlj3jXGhBljwipWrOjumJbztdt4e1QLPi33CBtNc8xXj8Oxn6yOpZQqhtxa/CLii7P0FxtjvgAwxpwzxmQaYxzAe0Brd2YoSkoH+PKvcW15zu9JjpsqOJaMhtgoq2MppYoZd17VI8C/gQPGmNeyTK+aZbZBwF53ZSiKqpUJ5B/jujAl8ykSUzNxLB4GieesjqWUKkbcucffARgNdLvu0s2XRWSPiOwGugKPuzFDkdS4WjCzHujLw2lPkH75NI4P+uvn9iqlCoyPux7YGLMRkGzu8urLN/OqS8NKnBs4hDHLM/gw9hV8P+iHjP0aSlawOppSqojTd+4WYiNa1aRPv2GMSZ1J+sWjmA/uhSuxVsdSShVxWvyF3Jj2teneZyjjUp8k/cIRzIf3QvIlq2MppYowLf4iYELH2+nYaxjjU58g43wk5sP+Wv5KqZumxV9ETO5clzY9hjIh9XEyzx3ELBoIVy9bHUspVQRp8Rch07rVJ6TrUCakziDz7H7MokFwNc7qWEqpIkaLv4iZ0aM+jTsPZWLqdDJj9jjLPyXe6lhKqSJEi7+IERFm9mxI/buGMjl1Oo6Y3ZiPhkBKgtXRlFJFhBZ/ESQizOpzBzXbDWVK6qM4Tu/ALB6qY/krpfJEi7+IEhGe7deIqm2GMjX1UcypcMxHWv5Kqdxp8RdhIsKc/o0p12oo09KmYk79AouHQ2qS1dGUUoWYFn8RJyK8MKAJQS2G8VjaIzhOboGPR0DaFaujKaUKKS3+YsBmc36Qi1/IUGakTcFx4mdX+SdbHU0pVQhp8RcTNpvwytDmSLNhPJE2GXN8I3xyH6RftTqaUqqQ0eIvRuw24e/DmpPeZBhPpk3CHNvg3PPXSz2VUllo8RczPnYbb4wIIbnRcGamTcJxfCP8uydcPm51NKVUIaHFXwz52m28dX8o8Q2HMTr1D6Rcisa81w1ObLY6mlKqENDiL6b8fGzMH9WCGi360id5DufTA53j+e9cbHU0pZTFtPiLMV+7jZeGNOW+3l3pmfQcu+2N4ctHYM1z4Mi0Op5SyiJu++hFVTiICJM616VW+ZI8sKQEf/ZbxKBNb8LFSBj8LvgHWR1RKeVhusfvJXo3qcLHkzryV3mYF81DmMPfwsLeEHfS6mhKKQ/T4vcizWqUYcW0u9hYbjDj0v5AWuwJeK8bnNpmdTSllAdp8XuZamUC+XxyO2z1e9DnynNcyvDD/Oce2LXE6mhKKQ/R4vdCpfx9eO/BMDq170C3+Oc45HMHLJ8Ia+eCw2F1PKWUm+nJXS9ltwmz721MnQolGbAykDeDPqL3T3+Hi4dh0DvgV9LqiEopN9Hi93IPtqtNzXIlmPaxPwfs1Zhx8ENkYW+4/1MIrm51PKWUG+ihHkWXhpVYNqUDS/0GMDnjD2RcjIL3ukL0dqujKaXcQItfAdCwShDLp7bnbJVO9E2eTUKGD+Y/fWHPUqujKaUKmBa/+lWloACWTGxL/cat6BL3LMf9GsCy8bDxDTDG6nhKqQKSp+IXkekiUlqc/i0iO0SkZy7L3CYi60TkgIjsE5HprunlRGSNiES6/i1bEP8RVTACfO384/5QRnZtQa9LT7I5sDN8Pxu+naXDPChVTOR1j/8hY0wC0BOoCIwDXsplmQzgSWNMI6AtMFVE7gRmAWuNMfWBta7vVSFiswkzezXkL8PCGJMwkaW+/WHrAlg6DtJTrI6nlLpFeb2qR1z/9gXeN8bsEhG50QLGmBggxnU7UUQOANWBAUAX12wfAOuBp/MXW3nC0JY1qFE2kMkf+XFCyvLk/g/gykW4bzEE6gs1pYqqvO7xbxeR1TiL/zsRCQLy/E4fEakNhAJbgcquPwrX/jhUymGZiSISLiLhFy5cyOuqVAFre3t5lj/Sga9KDubxzMdwnNwKC/tAfLTV0ZRSNymvxT8e5yGZVsaYZMAX5+GeXIlIKWAZMMN1uChPjDHvGmPCjDFhFStWzOtiyg3qVCjJF1Pac7pGXx5IfZrUSycx/7obzu23OppS6ibktfjbAYeMMXEiMgr4ExCf20Ii4ouz9BcbY75wTT4nIlVd91cFzuc/tvK0siX9+Gh8G6qF9GJg8p9IuJqKWdgLjm+0OppSKp/yWvzzgWQRaQ78ATgBfHijBVznAP4NHDDGvJblrpXAGNftMcCX+UqsLOPnY+PVYc3o17MnfZOe43RGMGbRINi33OpoSql8yGvxZxhjDM4Ts28aY94EcvsEjw7AaKCbiES4vvrivBrobhGJBO4m96uDVCEiIkztWo9nRvZkcOps9pi6mM/HwZYFVkdTSuVRXq/qSRSRZ3AWeUcRseM8zp8jY8xG/nc10PW65z2iKoz6NatGtTI9mPJBIM/zJj2+fRoSTkOP58Gm7wtUqjDL6zN0BJCK83r+szgvy3zFbalUkdCiZlk+ndqNV0s/w0eZd8PPb8HySZCRZnU0pdQN5Kn4XWW/GAgWkX5AijHmhsf4lXe4rVwJPnukI9/VmsnL6SNgz2eYxUMhJc8XcCmlPCyvQzYMB7YBw4DhwFYRGerOYKroKB3gy/vjWhMf9ihPpk3GcWwjjoV9IPGs1dGUUtnI66GeP+K8hn+MMeZBoDXwrPtiqaLGx27jhYFNuLPvZB5Km0nq+SNkvtcDLhy2OppS6jp5LX6bMSbr9fax+VhWeQkRYfxddRg9ajyjHbOJT0gg8189YP9Kq6MppbLIa3l/KyLfichYERkLfA2scl8sVZT1uLMycyY9wATfv3IgpTx8NhpWPgppV6yOppQCxORxnHURGYLz2nwBNhhjPPaunbCwMBMeHu6p1akCci4hhcc//oUO0e8yxee/UO52bEP/DdVCrY6mlFcQke3GmLDfTc9r8VtJi7/oynQY5q07wua1K3jLfz7lJQFbtz9B+8f0en+l3Cyn4r/hM09EEkUkIZuvRBHR6/VUruw24bHu9Xli4nhG+b7G6oxQ+H425sMBEH/a6nhKeaUbFr8xJsgYUzqbryBjTGlPhVRFX6va5fhsxj2sqPdX/pD+MGkntuGY315P/CplAX2trTwmuIQv80e3pHn/R+mf/hcO6olfpSyhxa88SkR4oE0t3po2nJmlX2Z+xr2YHYswCzrBmZ1Wx1PKK2jxK0s0rBLEF492Jbrl04xM+z9i4y47P9xl4xvgyPOHuymlboIWv7JMgK+dFwc1ZczI0QxyvMz3mc4TvywaAAlnrI6nVLGlxa8s17tJVT6dcQ/vVpnD0+kPk3p8G+bt9nDgv1ZHU6pY0uJXhUL1MoF8MrEdVbpMpG/qixxOKwdLRjlP/KYmWh1PqWJFi18VGj52G4/f3YAXJwxivP0vvJvpOvH7djuIWmd1PKWKDS1+Vei0vb08/53RjV/qz2BI6mxikhywaCD8d7qO869UAdDiV4VS2ZJ+vDu6JfcNGUr/jL+x0NyLY/uHzr3/I2utjqdUkabFrwotEWF42G18+XgP1tV8lCGpszmTLPDRYPhyGqTEWx1RqSJJi18VetXLBPLhQ60ZNnAw/dL+wr/MABw7Fzv3/iPXWB1PqSJHi18VCSLCyDY1WTmjB2urP8Kg1DmcvuoDi4fCiqlwNc7qiEoVGVr8qki5rVwJFk9ow9D+A7gn5QXeYyCOXZ9g3m4Lh7+zOp5SRYIWvypybDZhdLvarJzRnTVVJzMg5XlOp/jDx8Nh+RS4etnqiEoValr8qsiqVb4knz7clkH39KNPyp95h8E4di/BzGsLh76xOp5ShZYWvyrSbDbhobvq8OVj3Vhd+WHuTZlLdFoJ+OQ++GIiJF+yOqJShY4WvyoWbq9Yis8mtWNAnz70Tp7LAobh2LMMM68N7P4cisBHjCrlKW4rfhFZKCLnRWRvlmlzROS0iES4vvq6a/3K+9htwsROdfnysS58U3Ec/VLmciKjLHwxAT64Fy4csjqiUoWCO/f4/wP0zmb668aYENfXKjeuX3mpepWCWDa5Hf169aL3ldnMNRNIjY7AzO8A38/RT/tSXs9txW+M2QDoAVZlCR+7jUe61OO7x7sQVWsE7ZNeZo1PJ9j4OsxrAwe/1sM/ymtZcYx/mojsdh0KKpvTTCIyUUTCRST8woULnsynipFa5Uvyn3GteHFUV2bLVIalPkdMig98OtJ5AvjycasjKuVxYty41yMitYGvjDFNXN9XBi4CBvgzUNUY81BujxMWFmbCw8PdllN5hyupGbz1QyQf/BTJw/5reMy2FB9xIJ1mQvvHwMff6ohKFSgR2W6MCbt+ukf3+I0x54wxmcYYB/Ae0NqT61feraS/D8/0acTK6V3ZVmUkHa/8jZ/tYfDDCzC/PUT9YHVEpTzCo8UvIlWzfDsI2JvTvEq5S4PKQXw6sS1Pj+jG9MzHGZP2NLFJKbBoEHw+DhJirI6olFv5uOuBReQToAtQQUSigdlAFxEJwXmo5zgwyV3rV+pGRIRBoTXodkdlXltdhQ5bGjEj8BsePrAcW+QapOsz0HoS2N32FFHKMm49xl9Q9Bi/cre9p+P544q9XI4+xJulPyY09Reo3ATu+TvUbGt1PKVuSqE4xq9UYdWkejDLp7Rn8sAejE19iinpTxB/+QIs7AWfPQgXj1gdUakCo8WvlIvN5hzz/4eZXQgKHUi7hL/ynn0E6YfWYOa1hv/OgMSzVsdU6pbpoR6lchB+/BJzv9rPmeiTPFv6a+5N/xbx8UPaPgIdHoOAYKsjKnVDOR3q0eJX6gYcDsPXe2J45btDcPkYL5X5kvZX10NgOeg0E1pN0Ov/VaGlxa/ULUjLcLB46wn+8cMRqiYf4u9ll3NHcjgE14Ruf4Smw8BmtzqmUr+hJ3eVugV+PjbGdajDj091oWuXuxmYOJMxGf/HmfRAWD4J3ukEh1fr+D+qSNDiVyofggJ8mdmrIetndqVqaB86Xn6WmWYG8fFx8PEw+E8/iNZXp6pw0+JX6iZUCQ7gpSHN+HZGF+Ju70dY3F94xf4wKTH74V/dYclouBhpdUylsqXH+JUqANuOXeIvqw5w+NRZZgV/z8jML7FnpiCho6DjE1C2ttURlRfSk7tKuZkxhm/2nuXlbw+SGBvDC+W+oVfKN9iMA5oMgbseh8p3Wh1TeREtfqU8JD3TwafbTvLG95H4XDnLnIrr6Hn1G+wZydCgj/MVwG06MK1yPy1+pTwsKTWDDzcfZ+HGY2QkxTKr/AYGp3+NX1oc1OoAdz0B9bqDiNVRVTGlxa+URVLSM/k8/BTvbDhK7OXLTC/zMw/yFSVSzkKVZs5DQHcO0PcBqAKnxa+UxdIzHXy1+wzz10dx7FwcDwVt5RHfrwhOPgHl6kKH6dD8Pn0nsCowWvxKFRIOh2HtwfO8vf4Iu05eYliJCJ4s8RWVkg5CUFVoNw1ajgX/UlZHVUWcFr9ShYwxhi1HL/H2+iP8FHmBnv4H+GPwN9RK2A6BZZ0fBNNmEpQoZ3VUVURp8StViO2JjmfBj1Gs2htDK3sUz5dfTaP4n8C3hHPvv900CK5udUxVxGjxK1UEHL2QxDs/HuWLndHcbqJ5oeIawhLWImKD5iOgw+NQoZ7VMVURocWvVBESE3+Vf/10jE+2naRc+lmer/ADXZO/RTLTkDsHON8LULW51TFVIafFr1QRFJecxoebT/Cfn49ju3KBZ8qto3/aKnwzkqBud+cfgFod9L0AKlta/EoVYVfTMvl8+yne3XCU+MuxTC/9I6P4moC0S1CjtfMPQIPe+gdA/YYWv1LFQEamg1V7z7JgfRRRMReZUHIjk32/JiglBird6Xw3cONBYPexOqoqBLT4lSpGjDH8FHmRBT9GsS3qHMMCtvFk4FdUuHrcORJo+8cg5AHwDbA6qrKQFr9SxdSuU3G8syGKb/eeobfPTp4ptYrbrh6AUpWh7RRoMcE9OZQAAA9MSURBVEbfC+CltPiVKuaOXbzCez8dZen2U4Q59vJs8Lc0urod7P7QZDCEjYcaYXoewIto8SvlJc4npvCfTcdZtOUE1VOP8ljwBu5OX49vZjJUaQphD0HT4TokhBfQ4lfKyySmpLN0ezQfbz3JmfMXGOG/mYmB66mScgT8gpxvCAsbrx8OU4xp8SvlpYwxhJ+4zOItJ1i1N4YmmYd4LHgDHdM2YnekwW1todV459DQOjJoseLx4heRhUA/4LwxpolrWjlgCVAbOA4MN8Zczu2xtPiVKhiXr6SxbIfzVcDlizGMCtjEOP8fKJd6GkqUh9BRzrGByt1udVRVAKwo/k5AEvBhluJ/GbhkjHlJRGYBZY0xT+f2WFr8ShUsYwybj8by8daTrN53htZmL48G/UirtK3YTKbzXcGtxkP9XvqegCLMkkM9IlIb+CpL8R8CuhhjYkSkKrDeGNMwt8fR4lfKfS4mpfJ5eDSfbDtJ6qVoxgZu4AGfdZROvwClq0OzEdD8fqjYwOqoKp8KS/HHGWPKZLn/sjGmbA7LTgQmAtSsWbPliRMn3JZTKeX8gJhNURf5eOtJ1u4/Qxe280jQRpql7XC+CqjWwvkHoMkQKFne6rgqD4pc8Wele/xKedb5hBQ+Cz/FkvBTpFyKYbDfZkYHbqZG6hGMzQep38v5MZENeukJ4UIsp+L39MG7cyJSNcuhnvMeXr9SKg8qlQ5gWrf6TO1ajx0nL7N8ZzP67e5PldQoRgb8zKCoTQQd+hoTUAZpMsT5SkDfHFZkeHqP/xUgNsvJ3XLGmD/k9ji6x6+U9dIyHKw/dJ4VEadZdyCGVo7dPFhiM50dW/F1pDo/ML75/dBsOJStZXVchTVX9XwCdAEqAOeA2cAK4DOgJnASGGaMuZTbY2nxK1W4xF9N59u9MSzfeZq9R6PpY9/GmBJbaJK+2zlDrbuch4LuHAABpa0N68X0DVxKKbc4HXeVlRFnWL4zmuTzxxjss4kHAjZROf00xicAadALGg92ng/wDbQ6rlfR4ldKuZUxhv0xCazYeZovd56m2pX9jPDbRD+fbQRlXsb4lUIa9nUOGFe3m54U9gAtfqWUx2Q6DJujYvky4jRr953hjrTdDPLdSl/7L5R0JGACgpFG9zpfCdTprG8ScxMtfqWUJdIzHfwcFcuq3TGs3XeKJqk7GeS7lZ72cAIdyZgSFZA7+zvfH1CzHdjsVkcuNrT4lVKWS890sOVoLKv2xLBu7ymapfzCQN8tdLftxN+k4ChVBVvjQc4/Anp56C3T4ldKFSoZmQ62HrvEqj0x/Lj3OCFXtzLQdwudbRH4mnQcwTWxNRkEDftCjVb6SuAmaPErpQqtTIdh67FYvtlzlg17omh59Wf6+2ylo203djJxBJTD1uBu55VBdbtDYJncH1Rp8SulioZMhyH8uPOVwMa9R2l45Re623fSw2cXwSYBI3ZMzbbYGvSGBr2hQn09JJQDLX6lVJFjjOHg2UTWH7rAjwdjyDwVTmfZQQ97BHeIc+DGjODa+NzRBxr0hFod9DLRLLT4lVJFXmJKOpuOXGT9oQscOLifJslb6WbbyV32ffiTRqZPSaReN2wNe0P9nlCqktWRLaXFr5QqVowxHDrnfDXw84GT+J/aRBfZQXd7BFUkFoC0yiH4NeoLDXtDlWZed0hIi18pVaw5Xw3E8uOhc0Qf/IVmyVvpYd9Bc1sUNgypJari26g3toZ9oE4nrxg+QotfKeU1jDEcPpfEukPn2bn/EGVOr6eL7KCzfTclSCXDHkhm7c7439nXeaVQUBWrI7uFFr9SymslpKSzKfIiGw5Ek3RoPWGpW+lu30kNuQhAcoVmBDS+B9sdfYrVISEtfqWUwvkRk/tjElh/8BxR+36h+vn1dLPtIMR1SOhqQGVsDXvj3/ieIn9ISItfKaWycflKGhsiLxC+7zByZDXtMn6ho20PpSSFDJs/yZXDKNGwCz63d4bqLcDua3XkPNPiV0qpXGQ6DLuj49hw4DSx+36g9qVNtLPtp5HtJABptkCSq7Si1B1d8bm9E1QNKdQji2rxK6VUPsUmpbLt2CX2RB4l9cgGaiVup63sp4HtNACp9pIkV25FqUZd8a3b2Xl+oBCNKaTFr5RStyguOY1txy6x9/AR0qI2cFt8OG1lP3VtMQCk2INIqtKaoDu64l+vM1RuAjabZXm1+JVSqoAlpKQTfvwS+w4eIi1qA9XjttNG9lHHdg6Aqz6lSarYkpL1O1CibgeoFurRk8Va/Eop5WZJqRlsP3GZ/Qf2kx61gapx4YRymHq2MwBkiA/xwXdiq9WWMg3vQmq2c+uwElr8SinlYSnpmew6FcfeI0e5cmQzpc6H08RxkOZyFH9JByA+oAbp1VoR3LAjvrXbQcU7CuzwkBa/UkpZzOEwHL2YxM6j5zh3aCv2079Q5+peWtoOUVESALhqDyKxQigl6rWnVL0OUD0M/Erc1Pq0+JVSqhC6dCWN7ccvEXV4NxnHN1PxcgQhHKahLRqAA13eoVGX+27qsXMq/sJ7AapSSnmBciX9uLtxFe5uXAXoSWpGJntPJ/DBkRMkRG6iX72OBb5OLX6llCpE/H3stKxVlpa1ykL3ELesw7oLTJVSSllCi18ppbyMFr9SSnkZS47xi8hxIBHIBDKyO+uslFLKPaw8udvVGHPRwvUrpZRX0kM9SinlZawqfgOsFpHtIjIxuxlEZKKIhItI+IULFzwcTymlii+rir+DMaYF0AeYKiKdrp/BGPOuMSbMGBNWsWJFzydUSqliyvIhG0RkDpBkjHn1BvNcAE7c5CoqAIXxXILmyh/NlT+aK38Kay64tWy1jDG/23P2+MldESkJ2Iwxia7bPYG5N1omu+D5WF94YbxqSHPlj+bKH82VP4U1F7gnmxVX9VQGlovItfV/bIz51oIcSinllTxe/MaYo0BzT69XKaWUkzdczvmu1QFyoLnyR3Plj+bKn8KaC9yQzfKTu0oppTzLG/b4lVJKZaHFr5RSXqbYFL+I9BaRQyJyRERmZXO/v4gscd2/VURqeyDTbSKyTkQOiMg+EZmezTxdRCReRCJcX8+5O5drvcdFZI9rnb/7XEtxesu1vXaLSAsPZGqYZTtEiEiCiMy4bh6PbC8RWSgi50Vkb5Zp5URkjYhEuv4tm8OyY1zzRIrIGA/kekVEDrp+TstFpEwOy97wZ+6GXHNE5HSWn1XfHJa94XPXDbmWZMl0XEQicljWndsr227w2O+YMabIfwF2IAq4HfADdgF3XjfPI8AC1+37gCUeyFUVaOG6HQQcziZXF+ArC7bZcaDCDe7vC3wDCNAW2GrBz/QszjegeHx7AZ2AFsDeLNNeBma5bs8C/pbNcuWAo65/y7pul3Vzrp6Aj+v237LLlZefuRtyzQFm5uHnfMPnbkHnuu7+vwPPWbC9su0GT/2OFZc9/tbAEWPMUWNMGvApMOC6eQYAH7huLwW6i+vNBO5ijIkxxuxw3U4EDgDV3bnOAjQA+NA4bQHKiEhVD66/OxBljLnZd2zfEmPMBuDSdZOz/g59AAzMZtFewBpjzCVjzGVgDdDbnbmMMauNMRmub7cANQpqfbeSK4/y8tx1Sy7X83848ElBrS+vbtANHvkdKy7FXx04leX7aH5fsL/O43qSxAPlPZIOcB1aCgW2ZnN3OxHZJSLfiEhjD0XKbaC8vGxTd7qPnJ+QVmwvgMrGmBhwPnGBStnMY/V2ewjnK7Xs5Do4ohtMcx2CWpjDYQsrt1dH4JwxJjKH+z2yva7rBo/8jhWX4s9uz/3661TzMo9biEgpYBkwwxiTcN3dO3AezmgO/ANY4YlM5D5QnpXbyw/oD3yezd1Wba+8snK7/RHIABbnMEuugyMWsPlAXSAEiMF5WOV6lm0v4H5uvLfv9u2VSzfkuFg20/K1zYpL8UcDt2X5vgZwJqd5RMQHCObmXprmi4j44vzBLjbGfHH9/caYBGNMkuv2KsBXRCq4O5cx5ozr3/PAcpwvubPKyzZ1lz7ADmPMuevvsGp7uZy7drjL9e/5bOaxZLu5TvD1Ax4wrgPB18vDz7xAGWPOGWMyjTEO4L0c1mfV9vIBBgNLcprH3dsrh27wyO9YcSn+X4D6IlLHtbd4H7DyunlWAtfOfg8FfsjpCVJQXMcQ/w0cMMa8lsM8Va6daxCR1jh/JrFuzlVSRIKu3cZ5cnDvdbOtBB4Up7ZA/LWXoB6Q456YFdsri6y/Q2OAL7OZ5zugp4iUdR3a6Oma5jYi0ht4GuhvjEnOYZ68/MwLOlfWc0KDclhfXp677tADOGiMic7uTndvrxt0g2d+x9xxxtqKL5xXoRzGeYXAH13T5uJ8MgAE4Dx0cATYBtzugUx34XwJthuIcH31BSYDk13zTAP24byaYQvQ3gO5bnetb5dr3de2V9ZcAsxzbc89QJiHfo4lcBZ5cJZpHt9eOP/wxADpOPewxuM8J7QWiHT9W841bxjwryzLPuT6PTsCjPNAriM4j/le+x27dvVaNWDVjX7mbs61yPW7sxtnoVW9Ppfr+989d92ZyzX9P9d+p7LM68ntlVM3eOR3TIdsUEopL1NcDvUopZTKIy1+pZTyMlr8SinlZbT4lVLKy2jxK6WUl9HiV8rNxDmi6FdW51DqGi1+pZTyMlr8SrmIyCgR2eYaf/0dEbGLSJKI/F1EdojIWhGp6Jo3RES2yP/GwC/rml5PRL53DSK3Q0Tquh6+lIgsFee4+YvdPTKsUjeixa8UICKNgBE4B+YKATKBB4CSOMcNagH8CMx2LfIh8LQxphnOd6dem74YmGecg8i1x/muUXCOvjgD55jrtwMd3P6fUioHPlYHUKqQ6A60BH5x7YwH4hwgy8H/BvL6CPhCRIKBMsaYH13TPwA+d43tUt0YsxzAGJMC4Hq8bcY1Low4P/GpNrDR/f8tpX5Pi18pJwE+MMY885uJIs9eN9+Nxji50eGb1Cy3M9HnnrKQHupRymktMFREKsGvn31aC+dzZKhrnpHARmNMPHBZRDq6po8GfjTO8dSjRWSg6zH8RaSER/8XSuWB7nUoBRhj9ovIn3B+4pIN52iOU4ErQGMR2Y7zU9tGuBYZAyxwFftRYJxr+mjgHRGZ63qMYR78byiVJzo6p1I3ICJJxphSVudQqiDpoR6llPIyusevlFJeRvf4lVLKy2jxK6WUl9HiV0opL6PFr5RSXkaLXymlvMz/A+U/kQwynExnAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "from IPython import display\n",
    "import torch.utils.data as Data\n",
    "\n",
    "num_inputs = 200\n",
    "#1类\n",
    "x1 = torch.normal(2,1,(10000,num_inputs))\n",
    "y1 = torch.ones(10000,1)#标签1\n",
    "x1_train = x1[:7000]\n",
    "x1_test = x1[7000:]\n",
    "#0类\n",
    "x2 = torch.normal(-2,1,(10000,num_inputs))\n",
    "y2 = torch.zeros(10000,1)#标签1\n",
    "x2_train = x1[:7000]\n",
    "x2_test = x1[7000:]\n",
    "#合并数据----按行合并，即dim=0，且分训练集和测试集\n",
    "    ##合并训练集数据（包括特征和标签）\n",
    "trainfeatures=torch.cat((x1_train,x2_train),0).type(torch.FloatTensor)\n",
    "trainlabels = torch.cat((y1[:7000],y2[:7000]),0).type(torch.FloatTensor)\n",
    "    ##合并测试集数据\n",
    "testfeatures=torch.cat((x1_test,x2_test),0).type(torch.FloatTensor)\n",
    "testlabels=torch.cat((y1[7000:],y2[7000:]),0).type(torch.FloatTensor)\n",
    "print(trainfeatures.shape,trainlabels.shape,testfeatures.shape,testlabels.shape)\n",
    "\n",
    "#设置批量大小\n",
    "batch_size = 50\n",
    "#'''构造训练数据迭代器'''\n",
    "#将训练数据的特征和标签组合---形成训练数据集\n",
    "dataset1 = Data.TensorDataset(trainfeatures,trainlabels)\n",
    "train_iter = Data.DataLoader(\n",
    "    dataset=dataset1,#torch TensorDataset format\n",
    "    batch_size = batch_size,#mini batch size\n",
    "    shuffle =True,#是否打乱数据（训练集一般需要打乱）\n",
    "    num_workers=0,#多线程读取数据，注意在windows下需要设置为0\n",
    ")\n",
    "#'''构造测试数据迭代器'''\n",
    "dataset2 = Data.TensorDataset(testfeatures,testlabels)\n",
    "test_iter = Data.DataLoader(\n",
    "    dataset = dataset2,\n",
    "    batch_size = batch_size,\n",
    "    shuffle = False,#测试集一般不需要打乱\n",
    "    num_workers=0,\n",
    ")\n",
    "\n",
    "#定义超参数\n",
    "# num_inputs=200\n",
    "num_hiddens = 100\n",
    "num_outputs = 1\n",
    "#定义参数\n",
    "#参数初始化时，将参数的方差调大，增加不稳定性，方式初始参数刚好与理想参数接近的情况，可以更好的观察模型训练的效果。\n",
    "W1 = torch.tensor(np.random.normal(0, 1, (num_inputs,num_hiddens)), dtype=torch.float32)  \n",
    "b1 = torch.zeros(1, dtype=torch.float32)#wx+b时，会调用广播算法  \n",
    "W2 = torch.tensor(np.random.normal(0, 1, (num_hiddens,num_outputs)), dtype=torch.float32)  \n",
    "b2 = torch.zeros(1, dtype=torch.float32)  \n",
    "params = [W1,b1,W2,b2]\n",
    "for param in params:\n",
    "    param.requires_grad_(requires_grad = True)#设置为true，追踪并记录所有在计算图上的操作（正向积累）\n",
    "def relu(x):\n",
    "    x=torch.max(x,torch.tensor(0.0))\n",
    "    return x\n",
    "def net (X):\n",
    "    X=X.view((-1,num_inputs))\n",
    "    H=relu(torch.mm(X,W1)+b1)\n",
    "    return torch.mm(H,W2)+b2 \n",
    "# 定义二分类交叉熵损失函数\n",
    "loss = torch.nn.BCEWithLogitsLoss()\n",
    "def SGD (params,lr,batch_size):\n",
    "    for param in params:\n",
    "        param.data -= lr * param.grad/batch_size\n",
    "#记录列表（list），存储训练集和测试集上经过每一轮次，loss的变化\n",
    "def train (net,train_iter,test_iter,loss,num_epochs,batch_size,params = None,lr=None,optimizer=None):\n",
    "    train_loss=[]\n",
    "    test_loss=[]\n",
    "    for epoch in range(num_epochs):#外循环控制循环轮次\n",
    "        train_l_sum=0.0#记录训练集上的损失\n",
    "        test_l_sum=0.0\n",
    "        n =0.0\n",
    "        #step1在训练集上，进行小批量梯度下降更新参数\n",
    "        for X,y in train_iter:#内循环控制训练批次\n",
    "            y_hat = net(X)\n",
    "            #保证y与y_hat维度一致，否则将会发生广播\n",
    "            l = loss(y_hat,y.view(-1,1))#这里计算出的loss是已经求过平均的，l.size = torch.Size([]),即说明loss为表示*标量*的tensor`\n",
    "            #梯度清零\n",
    "            if optimizer is not None:\n",
    "                optimizer.zero_grad()\n",
    "            elif params is not None and params[0].grad is not None:\n",
    "                for param in params:\n",
    "                    param.grad.data.zero_()\n",
    "            l.backward()\n",
    "            if optimizer is None:\n",
    "                SGD(params,lr,batch_size)\n",
    "            else:\n",
    "                optimizer.step()\n",
    "             #计算每个epoch的loss\n",
    "             #train_l_sum += l.item()\n",
    "             #n+=y.shape[0]\n",
    "        #step2 每经过一个轮次的训练， 记录训练集和测试集上的loss\n",
    "        #注意要取平均值，loss已经默认求了平均值，因此我们不用再老费苦心，直接apply在测试集和训练集上。\n",
    "        test_l_sum = loss(net(testfeatures),testlabels).item()\n",
    "        train_l_sum = loss(net(trainfeatures),trainlabels).item()\n",
    "        train_loss.append(train_l_sum)\n",
    "        test_loss.append(test_l_sum)\n",
    "        print(\"epoch %d , train_loss %.6f , test_loss %.6f\"%(epoch+1,train_loss[epoch],test_loss[epoch])) \n",
    "    return train_loss, test_loss\n",
    "lr = 0.01\n",
    "num_epochs = 20\n",
    "train_loss,test_loss = train(net,train_iter,test_iter,loss,num_epochs,batch_size,params,lr)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "x=np.linspace(0,len(train_loss),len(train_loss))\n",
    "plt.plot(x,train_loss,label=\"train_loss\",linewidth=1.5)\n",
    "plt.plot(x,test_loss,label=\"test_loss\",linewidth=1.5)\n",
    "plt.xlabel(\"epoch\")\n",
    "plt.ylabel(\"loss\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1,train_loss 2.296501,test_loss 2.290209,train_acc 0.258117,test_acc 0.371500\n",
      "epoch 2,train_loss 2.282973,test_loss 2.274076,train_acc 0.448783,test_acc 0.507300\n",
      "epoch 3,train_loss 2.263522,test_loss 2.249873,train_acc 0.526783,test_acc 0.541800\n",
      "epoch 4,train_loss 2.233911,test_loss 2.212931,train_acc 0.536483,test_acc 0.555000\n",
      "epoch 5,train_loss 2.189217,test_loss 2.157996,train_acc 0.552650,test_acc 0.572700\n",
      "epoch 6,train_loss 2.124435,test_loss 2.080313,train_acc 0.582850,test_acc 0.606100\n",
      "epoch 7,train_loss 2.035622,test_loss 1.976561,train_acc 0.625350,test_acc 0.654000\n",
      "epoch 8,train_loss 1.920703,test_loss 1.846080,train_acc 0.671367,test_acc 0.690200\n",
      "epoch 9,train_loss 1.781345,test_loss 1.693639,train_acc 0.697367,test_acc 0.715100\n",
      "epoch 10,train_loss 1.625559,test_loss 1.531168,train_acc 0.714317,test_acc 0.726800\n",
      "epoch 11,train_loss 1.467003,test_loss 1.373927,train_acc 0.726817,test_acc 0.739100\n",
      "epoch 12,train_loss 1.319348,test_loss 1.233371,train_acc 0.740500,test_acc 0.759600\n",
      "epoch 13,train_loss 1.190471,test_loss 1.113681,train_acc 0.758083,test_acc 0.775900\n",
      "epoch 14,train_loss 1.081945,test_loss 1.014182,train_acc 0.775867,test_acc 0.788700\n",
      "epoch 15,train_loss 0.991912,test_loss 0.932015,train_acc 0.790083,test_acc 0.801600\n",
      "epoch 16,train_loss 0.917369,test_loss 0.864087,train_acc 0.801067,test_acc 0.810500\n",
      "epoch 17,train_loss 0.855328,test_loss 0.807398,train_acc 0.809850,test_acc 0.818300\n",
      "epoch 18,train_loss 0.803319,test_loss 0.759684,train_acc 0.817600,test_acc 0.827200\n",
      "epoch 19,train_loss 0.759318,test_loss 0.719033,train_acc 0.823800,test_acc 0.834600\n",
      "epoch 20,train_loss 0.721699,test_loss 0.684160,train_acc 0.829800,test_acc 0.840600\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3dd3xUVfrH8c+T3gskQAIEQu81NOkWulKUIiJdxLLqb1dE3V13111Xd9deAEEBRSmKggoiKFJFSkLvCT3U0AIBAinn98cdNMBMCJqZSTLP+/WaV5K55955MpnMd+69554jxhiUUkp5Li93F6CUUsq9NAiUUsrDaRAopZSH0yBQSikPp0GglFIezsfdBdyqqKgoU7lyZXeXoZRSxUpSUtJJY0y0vWXFLggqV65MYmKiu8tQSqliRUQOOFqmh4aUUsrDaRAopZSH0yBQSikPV+zOESilSp6srCxSU1PJzMx0dynFXkBAABUqVMDX17fA62gQKKXcLjU1ldDQUCpXroyIuLucYssYw6lTp0hNTSU+Pr7A6+mhIaWU22VmZlK6dGkNgd9JRChduvQt71lpECiligQNgcLxW55Hjzk0lHLiPF9vOkq5sADKhvlTNiyAsmEBlA72w8tLX4BKKc/lMUFwfMtSmq94mdTcaDaYaA6bKFJNNMckGhNSjqjw4F9DIjyAsqEBlAv/NTRC/H30E4tSqkTymCBoXTmE3L2BmDPb8b544pplOVe8OXU6miOno9mfU5r92aX42RYWh0w0x0wp/Pz8KRceQOXSwVSJCiY+Opj4qGCqRIVQNsxfQ0KpYuzs2bNMnz6dRx999JbW69atG9OnTyciIuKW1hs6dCg9evTgvvvuu6X1nMVjgoCqHfGq2tH6PusSpKfC2YNw9iDe6Ycoc/YgZc4epNHZ3ZjzRxF+nbktFy/O+0ZzPKssO49UZM2e8szKrsguU5HL+BHk520FRHTekAghPiqY8MCCd+FSSrnH2bNnGTdu3A1BkJOTg7e3t8P1vv32W2eX5hKeEwR5+QZCVHXrZodkX4FzvwaF19lDhJ89SPjpvdQ4sYx7vDPAG4x4kR5UmUN+Vdlh4lh3oDwzt5TjhPn100HpYD+qXN17iA6hUcUIGlaIINDP8YtLKU/2j2+2sf3IuULdZp3YMP52d12Hy5999ln27NlDo0aN8PX1JSQkhJiYGDZu3Mj27dvp1asXhw4dIjMzkyeffJJRo0YBv459lpGRQdeuXWnTpg2rVq2ifPnyfPXVVwQGBt60tsWLF/P000+TnZ1Ns2bNGD9+PP7+/jz77LN8/fXX+Pj40KlTJ1599VU+//xz/vGPf+Dt7U14eDjLly8vlOfHM4PgZnz8oFQV63a93Fw4ux+ObUWObSHi+FYijm2hfvr39APwh+zAKM6G1eKwfxV2mMqsuxTL8p2l+Cwx29q8l1AnNowmcZE0rWTdYiNu/oJRSjnHK6+8wtatW9m4cSNLly6le/fubN269Ze++JMnT6ZUqVJcunSJZs2ace+991K6dOlrtpGcnMyMGTOYNGkS/fr144svvmDQoEH5Pm5mZiZDhw5l8eLF1KhRg8GDBzN+/HgGDx7MnDlz2LlzJyLC2bNnAXjxxRdZuHAh5cuX/+W+wqBBcKu8vH4NiTr3/Hr/pTNwbCsc34rPsS1EHdtC1OGZNMy5wgAAb3+yK9fmcGQLfvZqzLzTYcxad4ipq/YDEBMeQJNKkTS1hUOd2DB8vbV3r/I8+X1yd5XmzZtfc0HW22+/zZw5cwA4dOgQycnJNwRBfHw8jRo1AqBp06bs37//po+za9cu4uPjqVGjBgBDhgzhvffe4/HHHycgIICRI0fSvXt3evToAUDr1q0ZOnQo/fr1o0+fPoXxqwIaBIUnMBLi21q3q3Ky4ORuKyCObcbn8Hoq7fqQSrnZDPALJbdOe45Et2GVNGLFiQDWHzjD/M1HAQjw9aJBhQhrjyEukiaVIikV7OemX04pzxIcHPzL90uXLuWHH37g559/JigoiA4dOti9YMvf3/+X7729vbl06dJNH8cYY/d+Hx8f1q5dy+LFi5k5cybvvvsuP/74IxMmTGDNmjXMnz+fRo0asXHjxhsC6bfQIHAmb18oW9e6Nexv3ZeZDvuWQ/L3eKUspsKuefQD+kXXhsZ3cCq2PWuza7Iu9QJJB88waflexudaL5Yq0cF0rluOHg1iqBMTpj2VlCokoaGhnD9/3u6y9PR0IiMjCQoKYufOnaxevbrQHrdWrVrs37+flJQUqlWrxrRp02jfvj0ZGRlcvHiRbt260bJlS6pVqwbAnj17aNGiBS1atOCbb77h0KFDGgTFUkA41L7buhkDabsg5XtI+QHWTqR0zrt09Q2ia3w7aHYnl/vezqYLkSQdOMOqPSeZuHwv45fuoUp0MHc3iOXuhjFUKxPq7t9KqWKtdOnStG7dmnr16hEYGEjZsmV/WdalSxcmTJhAgwYNqFmzJi1btiy0xw0ICGDKlCn07dv3l5PFo0eP5vTp0/Ts2ZPMzEyMMbzxxhsAjBkzhuTkZIwx3HHHHTRs2LBQ6hBHuya/e8MiFYGPgXJALjDRGPPWdW0eAMbafswAHjHGbMpvuwkJCabEzlB25QLsW2GFQsr3cGa/dX/palDtTqh2F6fKtuK7HSf5ZtMR1uw7jTFQq1wodzeMpUeDGCqVDs73IZQqinbs2EHt2rXdXUaJYe/5FJEkY0yCvfbODIIYIMYYs15EQoEkoJcxZnueNrcBO4wxZ0SkK/B3Y0yL/LZbooMgL2Pg9F4rFJK/h/0rIDsTwipAwjBoMoTjuaF8u+Uo8zYfJenAGQAaVginR4NYujeI0Z5IqtjQIChcRSYIbnggka+Ad40x3ztYHglsNcaUz287HhME18u6ZAXCug9g3zLw9oO6vaH5KCjflMPpmczffIRvNh1ly+F0ABIqRXJ3w1i61i9HmdAAN/8CSjlWUoPgscce46effrrmvieffJJhw4Y59XGLZBCISGVgOVDPGGP3ShEReRqoZYwZaWfZKGAUQFxcXNMDBxzOwewZ0nZZgbBxOlzJgNjG0OwhqNcHfAPZf/IC8zYfYd7mo+w8dh4vgZZVSnN3w1h6NSqvF7OpIqekBoG7FLkgEJEQYBnwkjHmSwdtOgLjgDbGmFP5bc9j9wjsuXweNs2EtZPg5C4ILAVNBkPCcIisBEDy8fN8s/ko8zYdYe/JC0SF+PNwuyo80DKOID/tK6CKBg2CwlWkgkBEfIF5wEJjzOsO2jQA5gBdjTG7b7ZNDQI7jLG6pK6bBDvnWz/X7ArNRkKVjuDlhTGGdfvP8PbiZFamnCQqxI+H2lbhwVaVNBCU22kQFK5bDQKnvQOI1cn9Q6yTwY5CIA74EniwICGgHBCBKu2tW3oqJE6BpKmw61urx1Gzh5BG99M8vhSfjGxB4v7TvLU4mZcX7OT95Xt5qG0VBreqRLC/BoJSnsiZvYbaACuALVjdRwGeB+IAjDETROQD4F7g6kH/bEeJdZXuERRQ9mXY/hWsnQip68A32Lqord0YCIsFIOmAtYewbHcakUG+jGxbhSG3VSZEA0G5mO4RFK4idWjIGTQIfoMjG2DtB7DlM/DyhfbPQMtHrcH1gA0HrUBYsiuNiCBfRraJZ8htlQkN0CG0lWu4Owh+63wEAG+++SajRo0iKCjIYZuro5RGRUX9njIL7FaDQEc18wSxjaHXe/DYWohvBz/8DcbfBnt+BKBxXCRThjXnq8da0zQuklcX7ab1Kz/y1g/JnMvMcnPxSjnf1fkIfos333yTixcvFnJFrqXHADxJqXgYOBN2L4QFY2Fab2uoi87/hog4GlaM4MOhzdiSms5bi5N544fdfLByL8NbxzO8TbxOsqNcY8GzcGxL4W6zXH3o+orDxXnnI7jrrrsoU6YMn332GZcvX6Z379784x//4MKFC/Tr14/U1FRycnL461//yvHjxzly5AgdO3YkKiqKJUuW3LSU119/ncmTJwMwcuRInnrqKbvb7t+/v905CZxBg8AT1egM8e3h53dh+auQ/AO0/RPc9gfwDaB+hXA+GJLA1sPpvL04mbcWJzN55T6GtYlndPsq2stIlTh55yNYtGgRs2fPZu3atRhjuOeee1i+fDlpaWnExsYyf/58wBqMLjw8nNdff50lS5YU6LBPUlISU6ZMYc2aNRhjaNGiBe3bt2fv3r03bPv06dN25yRwBv2P9lS+AdDuaWjQHxb9GZb8CzZ+Al3+AzW7AFCvfDgTByew7Ug67yxO4e3FyXy98TCv9WtI00ql3PwLqBIrn0/urrBo0SIWLVpE48aNAcjIyCA5OZm2bdvy9NNPM3bsWHr06EHbtm1vsqUbrVy5kt69e/8yzHWfPn1YsWIFXbp0uWHb2dnZduckcAY9R+DpIipCv4/hwbnWsBUz+sP0/tY4RzZ1Y8OZ8GBTZjzUkuxcw30Tfublb3eQmZXjxsKVcg5jDM899xwbN25k48aNpKSkMGLECGrUqEFSUhL169fnueee48UXX/xN27bH3ravzklw7733MnfuXLp06fJ7fzWHNAiUpWpHGP0TdPoX7F8J77WAH/8FV349Cdaqamm+e6odA5rF8f7yvdz9zkq2pKa7sWilCkfe+Qg6d+7M5MmTycjIAODw4cOcOHGCI0eOEBQUxKBBg3j66adZv379DeveTLt27Zg7dy4XL17kwoULzJkzh7Zt29rddkZGBunp6XTr1o0333yTjRs3OueXRw8Nqbx8/KzzBPXug+9fgOX/s4aw6Pxv66SyCCH+Przcpz6d65Zl7Beb6TXuJx7vWI3Hb6+mU2uqYivvfARdu3Zl4MCBtGrVCoCQkBA++eQTUlJSGDNmDF5eXvj6+jJ+/HgARo0aRdeuXYmJibnpyeImTZowdOhQmjdvDlgnixs3bszChQtv2Pb58+ftzkngDHodgXJs/0+w4Bk4vtUaqqLrfyG6xi+L0y9m8fdvtjFnw2HqlQ/jtb6NqFlOJ8lRt87d1xGUNHodgSo8lVvDqGVWABxeb117sHqCNZYREB7kyxv9GzFhUFOOns3k7ndWMmHZHnJyi9eHC6U8nR4aUvnz9oEWD0PdPvDNE/DdWOtK5R5vgJ91JWWXeuVoVjmSP8/ZyisLdrJo2zFe69eI+CidLU15lhYtWnD58uVr7ps2bRr169d3U0UFo0GgCiYkGvp/CitehSX/hhPbof8nvwx3XTrEn/GDmvDVxiO88NVWur61nOe61ubBlpXw8hI3F6+KA2MM1liVxdeaNWvcXYLDnkn50UNDquC8vKxxigbOgjMHYGIH2PPryTERoVfj8iz6v/a0rFKav329jUEfriH1TPG+/F45X0BAAKdOnfpNb2LqV8YYTp06RUDArc1IqCeL1W9zag/MfMCaEOfOv8NtT1jDYdsYY5i17hD/nLcdEeGFHnXom1Ch2H/iU86RlZVFamoqmZmZ7i6l2AsICKBChQr4+l47JIyOPqqc43IGfPUYbJ9rzZ98z7vgH3JNk0OnLzJm9iZW7z3N7bXK8FrfhkQG+7mpYKU8l/YaUs7hHwJ9p8JdL1pzH3x4l7WnkEfFUkFMH9mSv91dh5UpJ7l3/CoOntJDRUoVJU4LAhGpKCJLRGSHiGwTkSfttBEReVtEUkRks4g0cVY9yklEoPWTMOgLOH8UJnWE3YuuaeLlJQxrHc/0kS04ffEKfcb/xKZDzhtASyl1a5y5R5AN/MkYUxtoCTwmInWua9MVqG67jQLGO7Ee5UxVb7euOYiIg+n9YNl/ITf3miYJlUvxxSO3EejnzYCJq1m847ibilVK5eW0IDDGHDXGrLd9fx7YAZS/rllP4GNjWQ1EiEiMs2pSThZZCYYvggb9YMlLMGsQZJ67pknV6BC+fKQ11cuG8NDHiXy65oCDjSmlXMUl5whEpDLQGLi+k2154FCen1O5MSxUceIXBL3ft65G3v0dTLod0nZd0yQ61J+Zo1rSoWYZ/jxnK//5bie5ejWyUm7j9CAQkRDgC+ApY8y56xfbWeWGdwQRGSUiiSKSmJaW5owyVWESsa5GHvINZJ61wmDHN9c0CfLzYeKDTRnYIo7xS/fwx882ciU718EGlVLO5NQgEBFfrBD41BjzpZ0mqUDFPD9XAI5c38gYM9EYk2CMSYiOjnZOsarwXR2rKLqmdZho8T+vOW/g4+3FS73qMaZzTeZuPMKQyWtJv6RzJCvlas7sNSTAh8AOY8zrDpp9DQy29R5qCaQbY446qyblBuHlYdgCaDLYGp7i6z9cEwYiwmMdq/FG/4YkHjhN3wmrOHL2khsLVsrzOHOsodbAg8AWEbk6o8LzQByAMWYC8C3QDUgBLgLDnFiPchcff7jnHQgrD0tftu675x1ryAqb3o0rUDY0gIenJdF73E9MGdqcOrFhbipYKc+iVxYr11r6ihUGjR6whYH3NYt3HjvHsCnrOJ+ZzfhBTWhbXQ8FKlUY9MpiVXR0eBY6PAcbP7UdJrp23uNa5cL48tHbqBAZyLAp65idlOqmQpXyHBoEyvVuEgYx4YF8NroVLaqU4unPN/H24mQdlVIpJ9IgUO7R4Vno8LwVBl89fkMYhAX4MmVoc/o0Kc/r3+/muS+3kJWj3UuVcgadmEa5T4ex1tel/7a+9nz3mnMGfj5evNa3IRUiAnn7xxSOpmcyflATgvz0ZatUYdI9AuVeHcZaewabpltDWl+3ZyAi/LFTTV7pU58VyWk8PC2Jy9k5DjamlPot9KOVcr8b9gzeu6E30YDmcXh7CWNmb+bJGRt5d2BjfLz1c4xShUGDQBUNHcZaQ1Msecn62U4Y9E2oyPnMbF6ct51nv9zCf+9toPMhK1UINAhU0dH+GevrkpfAGOg17oYwGN4mnvOZ2bzxw25C/H342911dPpLpX4nDQJVtOQNA7AbBk/cUY1zmVl8uHIf4YG+/N9dNVxcpFIliwaBKnraPwMILPmX9fN1YSAi/KV7bc5nZvHW4mRCA3wY2baKe2pVqgTQIFBFU/sx1td8wuDlPg3IuJzNv+bvICzAl37NKtrZkFLqZjQIVNF1TRgY6DX+mjDw9hLe6N+IjMtJPPvlZkICfOhWXye4U+pWaf87VbS1HwMd/wKbZ8HcR264zsDfx5sJg5rQJC6SJ2duYNlunbhIqVulQaCKvvZj4HZbGHz33A2Lg/x8mDysGTXKhvLwtETW7T/thiKVKr40CFTx0G4MtHwM1r4PqyfcsDgswJePhjcnNiKQ4VPWsfVwuhuKVKp40iBQxUenf0LN7rDwOdj13Q2Lo0L8+WREC8ICfRkyeS170jLcUKRSxY8zp6qcLCInRGSrg+XhIvKNiGwSkW0iorOTqfx5ecO9k6BcA5g9HI5uuqFJbEQgn4xsgYgw6IM1pJ656IZClSpenLlHMBXoks/yx4DtxpiGQAfgNRHxc2I9qiTwC4b7Z0JgJEzvD+mHb2gSHxXMtBHNuXA5m0EfrCHt/GU3FKpU8eG0IDDGLAfyO2tngFDbJPchtrbZzqpHlSBhMTBwFlzOgBn9ra/XqR0TxpRhzTl+7jIPfriG9ItZbihUqeLBnecI3gVqA0eALcCTxhi7M4+IyCgRSRSRxLQ07R6ogHL1oO9UOL7dOkyUe+PQ1E0rRTJxcFP2pl1g2NS1XLisnzOUssedQdAZ2AjEAo2Ad0UkzF5DY8xEY0yCMSYhOlonM1c21e+Ebv+F5IV2u5UCtK0ezdv3N2bjobM6l4FSDrgzCIYBXxpLCrAPqOXGelRx1GwktHrcYbdSgC71yvHf+xqyMuUkz32xRec/Vuo67gyCg8AdACJSFqgJ7HVjPaq4uuvFfLuVAtzXtAJ/vKsGX244zLile1xcoFJFmzO7j84AfgZqikiqiIwQkdEiMtrW5J/AbSKyBVgMjDXGnHRWPaoEK0C3UoA/3F6Nno1i+d/CXXy39aiLi1Sq6JLitpuckJBgEhMT3V2GKorOH4NJd4DJgZGLIbz8DU0ys3IYMHE1u46d5/PRrahXPtwNhSrleiKSZIxJsLdMryxWJUdouZt2Kw3w9Wbi4KZEBvky8qNEjp/LdEOhShUtGgSqZClAt9IyoQF8OLQZ5zKzeOjjRC5d0Z5EyrNpEKiSpwDdSmvHhPH2gMZsOZzO059vIje3eB0iVaowaRCokilvt9I179ttcmedsjzXtRbztxzlzcXJLi5QqaJDZyhTJdddL8KZ/fDdsxBRCWreOPTVQ22rkHw8g7cXJ1M1OpiejW48waxUSad7BKrk8vKGPhPz7VYqIrzUuz7N40sxZvZmNhw844ZClXIvDQJVsvkFWz2Jro5Weu7IjU18vJgwqCnlwgJ46OMkDp+95IZClXIfDQJV8oWWgwc+g8vnYdYgyLqxy2ipYD8+HJLA5awcRn6UqAPUKY+iQaA8Q9m60HsCHE6Cb/8Edi6krF42lHcfaMKuY+d4cuZG7UmkPIYGgfIcte+G9mNhwyew7gO7TdrXiOZvd9flhx3H+c/CnS4uUCn30F5DyrO0fxaObrZ6EpWpDZXb3NBkcKtKJJ84z/vL9lItOoS+CRXdUKhSrqN7BMqzeHlZPYki4+GzIXD20A1NRIS/3V2XNtWieH7OFtbuy2+iPaWKPw0C5XkCwuD+GZBzBWY9AFk39hLy9fbivQeaULFUEA9PS+TgqYtuKFQp19AgUJ4pqjr0mWQdJvrmSbsnj8MDfZk8pBm5BoZ/tI5zmTrvsSqZNAiU56rZBTr+GTbPgtXj7DapHBXMhEFN2X/yAo9P30B2jt1ptZUq1pw5Mc1kETkhIlvzadNBRDaKyDYRWeasWpRyqO2frN5Ei/4Ke5fabdKqamn+1asey3en8dK3O1xbn1Iu4Mw9gqnAjYO72IhIBDAOuMcYUxfo68RalLLPywt6jYeoGvD5MGtsIjsGNI9jeOt4pvy0n1nrDrq2RqWczGlBYIxZDuTX3WIg1uT1B23tTzirFqXy5R8KAz61ZjabOQiuXLDb7PlutWhXI5q/zN3Kuv3ak0iVHO48R1ADiBSRpSKSJCKD3ViL8nSlq8K9k+H4Vvjqcbsnj328vXjn/sZUjAxi9LQkUs9oTyJVMrgzCHyApkB3oDPwVxGpYa+hiIwSkUQRSUxLS3NljcqTVL8T7vwbbPsSfnrLbpPwQF8mDUngSk4uD32cpGMSqRLBnUGQCnxnjLlgjDkJLAca2mtojJlojEkwxiRER0e7tEjlYVo/BXV7ww9/h+Qf7DapGh3CuwOtMYn+9JnObqaKP3cGwVdAWxHxEZEgoAWgXTKUe4lAz/esQeq+GA6n9tht1r5GNM93q813247p7Gaq2HNm99EZwM9ATRFJFZERIjJaREYDGGN2AN8Bm4G1wAfGGIddTZVyGb9g6+SxeMHMB6zhq+0Y0Saevk0r8PbiZOZvPuriIpUqPGLsnBQryhISEkxiYqK7y1CeYM8S+KQP1OwG/aZZXU2vczk7h4GT1rDtSDqzR99GvfLhbihUqZsTkSRjTIK9ZXplsVKOVO0Id/0Tds6DFa/ZbeLv482EQU0pFeTHQx8ncuL8jZPeKFXUaRAolZ9Wj0H9frDkJdj1nd0m0aH+TBqSwNmLWYyelsTl7BwXF6nU71OgIBCRJ0UkTCwfish6Eenk7OKUcjsRuOdtiGkAXz4EabvtNqsbG87r/Rqy/uBZnv9yK8XtkKvybAXdIxhujDkHdAKigWHAK06rSqmixDcQ+n8K3n4woz9ctH9Vcdf6MTx1Z3W+WJ/KByv2ubhIpX67ggaB2L52A6YYYzbluU+pki+iotWTKD0VPh8KOfaHpH7i9up0q1+OlxfsYMkuHTVFFQ8FDYIkEVmEFQQLRSQU0PF4lWeJawk93oR9y6ypLu3w8hJe7duQWuXCeGL6BlJO2O96qlRRUtAgGAE8CzQzxlwEfLEODynlWRo/ALf9AdZ9AGsn2W0S5OfDpCEJ+Pt6MfKjRM5evOLiIpW6NQUNglbALmPMWREZBPwFSHdeWUoVYXf+A2p0gQVjrWsN7CgfEcj7DzblyNlMndBGFXkFDYLxwEURaQg8AxwAPnZaVUoVZV7e1jSXUTXg8yEOh6FoWqkUL/Wux8qUk/xrvo6eooquggZBtrH6w/UE3jLGvAWEOq8spYq4gDAYOBPEG6b3h0tn7Tbrm1CRkW3imbpqPzPW6oQ2qmgqaBCcF5HngAeB+SLijXWeQCnPFVkZ+n9izWo2exjk2B+S+rlutWlfI5q/zt3Kz3tOubREpQqioEHQH7iMdT3BMaA88D+nVaVUcVG5NfR4Hfb8CIv+bLeJt5fwzsDGxEcFM2paIruOaU8iVbQUKAhsb/6fAuEi0gPINMboOQKlAJoMhpaPwZoJkDjFbpOwAF+mDm9OoK83Q6es5Vi6jkmkio6CDjHRD2uo6L5AP2CNiNznzMKUKlY6/ROq3QnfPg37ltttUj4ikKnDmnM+M5uhU9ZyLtP+RWlKuVpBDw39GesagiHGmMFAc+CvzitLqWLGyxvumwylqsJng+H0XrvN6sSGMWFQU1JOZDB6WhJXsrVbqXK/ggaBlzEm7/Xyp25hXaU8Q0C41ZMIYPoAyLR/qU2b6lH8974GrNpzimdm61SXyv0K+mb+nYgsFJGhIjIUmA98m98KIjJZRE6ISL6zjolIMxHJ0UNNqkQoVQX6fQyn98DsEZBrf0jqPk0qMKZzTeZuPML/Fu1ycZFKXaugJ4vHABOBBlgTzE80xoy9yWpTgS75NbB1Q/0PsLAgdShVLMS3g27/g5Tv4fsXHDZ7tENVHmgRx/ile5j2836XlafU9XwK2tAY8wXwxS20Xy4ilW/S7A+2bTYr6HaVKhYShsOJnfDzuxBd0+pZdB0R4cWe9Th+7jIvfL2NMmEBdK5bzg3FKk+X7x6BiJwXkXN2budF5NzveWARKQ/0BiYUoO0oEUkUkcS0tLTf87BKuU7nf0PV22HeH2H/T3abeHsJ79zfmIYVInhixgaSDpxxcZFK3SQIjDGhxpgwO7dQY0zY73zsN4GxxpibzutnjJlojEkwxiRER0f/zodVykW8feC+KRBZCT570LoC2Y5AP28+HJJATIk93NgAABsrSURBVHgAIz9ax960DNfWqTyeO3v+JAAzRWQ/cB8wTkR6ubEepQpfYATcPwtys209iezvSJcO8eej4c3xEmHIlLWknb/s4kKVJ3NbEBhj4o0xlY0xlYHZwKPGmLnuqkcpp4mqZvUkOrkbZg6ELPtXFVcqHczkoc04ef4Kw6eu48Jl+2MXKVXYnBYEIjID+BmoKSKpIjJCREaLyGhnPaZSRVaVDtB7AuxfAbOHOxygrmHFCN4d2JhtR9J5fPp6ncdAuYRYo0sXHwkJCSYxMdHdZSj126yZCAvGQMOB0PM98LL/WWz6moM8P2cL/RMq8sq99RHRKcLV7yMiScaYBHvLCtx9VClVCFqMgsyzsOQl60rkLi+DnTf5gS3iOJp+iXd+TCE2IpAn76zuhmKVp9AgUMrV2o2BS2dg9TgIKgXtn7Hb7I931eDI2Uze+GE3MeEB9GtW0cWFKk+hQaCUq4lAp5esWc2WvAQBEdaewg3NhFfurc+J85k8N2cLZcL86VCzjBsKViWdDhynlDt4ecE970DN7tY5g82f2W3m6+3F+EFNqVk2lEc/Xc/GQ/anxFTq99AgUMpdvH2soasrt4U5o2HXArvNQvx9mDqsGVEh/jz4wRrWH9Srj1Xh0iBQyp18A+D+GRDTED4fCvtX2m1WJiyAmaNaUjrEj8EfriXpwGnX1qlKNA0CpdzNPxQemA0Rlayrj49stNssNiKQmaNaUSbUn8EfrmXtPg0DVTg0CJQqCoJLw4NzIDASPukDabvtNisXbu0ZlAsPYMjktfy855SLC1UlkQaBUkVFeHkYPBfEC6b1hrOH7DYrExbAjFEtqRAZyLCpa1mVctLFhaqSRoNAqaKkdFUY9CVcPg/TekGG/WHXy4RaYVCpVDDDpq5jRbIOz65+Ow0CpYqamAYwcBakH7YOEzmY+zgqxJ/pD7UgPiqYER8lsnTXCbvtlLoZDQKliqJKraD/NDixHWbcD1mX7DYrHeLPjIdaUi06hFEfJ7Fkp4aBunUaBEoVVdXvgt7vw4FV8PkwyMmy2ywy2I/pD7WgRrkQHp6WxA/bj7u4UFXcaRAoVZTVvw+6vwa7F8BXj0Gu/WGpI4L8+HRES2rHhPLIp0ks2nbMxYWq4kyDQKmirtkIuOMF2DwL5j7icM8gPMiXj0e0oG5sOI9+up7vth51caGquNIgUKo4aPNH6PgX2DwTZj4AVy7abRYe6Mu0Ec1pWDGCx6ZvYP5mDQN1c86coWyyiJwQka0Olj8gIpttt1Ui0tBZtShV7IlA+zHQ4w1IXmRdZ3DJ/phDoQG+fDS8OU3iInhi5ga+3nTExcWq4saZewRTgS75LN8HtDfGNAD+CUx0Yi1KlQwJw6HvVDiyHqZ0g3P2P/FbA9U1p2mlSJ6auYG5Gw67tk5VrDgtCIwxywGHg6EYY1YZY65+pFkNVHBWLUqVKHV7wQOfw9mDMLkTnNpjt1mwbdTSFvGl+b/PNvJFUqqLC1XFRVE5RzACsD8GLyAio0QkUUQS09L0CkqlqNIBhnwDVy7Ah50cDlQX5OfD5KHNaF01iqdnb2LKT/sobvOUK+dzexCISEesIBjrqI0xZqIxJsEYkxAdHe264pQqyso3geGLwDcIpvaAfcvtNgv08+aDIQncWbss//hmO8/P2UpWjv1uqMozuTUIRKQB8AHQ0xijwygqdauiqsGIhRBeAT65F7Z/bbdZgK837w9qyiMdqjJj7UEe/HANZy5ccXGxqqhyWxCISBzwJfCgMcb+mLtKqZsLi4Vh30JMI/h8CCRNtdvMy0sY26UWb/RvyPqDZ+n53k8kHz/v2lpVkeTM7qMzgJ+BmiKSKiIjRGS0iIy2NXkBKA2ME5GNIpLorFqUKvGCSsHgr6DqHfDNk7D8VXBwLqB34wrMHNWSi1dy6DNuFUt0sDqPJ8XtxFFCQoJJTNTMUMqunCxrKIrNs6Dlo9DpJfCy/3nv8NlLPPRRIjuPneP5brUZ0SYeEXFxwcpVRCTJGJNgb5nbTxYrpQqRty/0mmCFwOpxMHe0wyEpykcEMvuRVnSqU45/zd/B2C82cyVbTyJ7Ig0CpUoaLy/o/O9fxyeacb/VzdSOID8fxj3QhCdur8ZniakM+mANpzIuu7hg5W4aBEqVRCLQ9k9w91uwZzF83Asu2r++08tL+GOnmrx9f2M2pVonkXcd05PInkSDQKmSrOlQ6PsRHN0IU7o6nAcZ4J6Gscx6uBVXsnPpM+4nFu/QeQ08hQaBUiVdnXtg0Bdw7gi83xaSv3fYtFHFCL5+vA1VokMY+XEi7y/bo1ciewANAqU8QXw7eGgJhJWHT++Dxf+E3By7TcuFB/DZw63oVi+Glxfs5OnPN3M5235bVTJoECjlKaKqwcgfoPEgWPEqfNwTzts//BPo5827Axvz1J3V+WJ9KvdPXE3aeT2JXFJpECjlSXwDoed70HMcpCZah4r2r7TbVER46s4avDewCduPnqPXez+x6dBZFxesXEGDQClP1PgBeGgx+IfCR3fDitcczofcvUEMnz98G8YY+oxfxRvf79ZB60oYDQKlPFXZujBqKdTpBYtfhBkDHHYxrV8hnAVPtaNnw1jeWpzMveNXkXIiw6XlKufRIFDKk/mHwn2TodursOdHeL89pCbZbRoe6Mvr/Rsx/oEmHDp9ke5vr2Dyyn3k5mqvouJOg0ApTycCzR+C4Qutnyd3hjUTHQ5a17V+DAv/rx1tqkXx4rztDPpwDYfPXnJhwaqwaRAopSwVmsLDy6DaHbBgDMweBpftX2FcJjSAD4Yk8J9767Pp0Fm6vLGcL5JS9ZqDYkqDQCn1q6BSMGAG3Pl3a5KbiR3g+Da7TUWE/s3iWPBkO2rFhPKnzzfxyCfrdayiYkiDQCl1LS8vaPN/1pzIlzNg0h2w4VOHzeNKBzFzVCue61qLH3eeoPOby/lhuw5PUZw4c2KaySJyQkS2OlguIvK2iKSIyGYRaeKsWpRSv0Hl1jB6BVRsBl89as1zcOWi3abeXsLD7avy9R9aEx0awMiPExk7ezPnM+0Pga2KFmfuEUwFuuSzvCtQ3XYbBYx3Yi1Kqd8ipAw8OBfajYENn8C4lpDyg8PmtcqFMfex23i0Q1U+TzpE17dWsGavTkde1DktCIwxywH7nZItPYGPjWU1ECEiMc6qRyn1G3l5w+1/gaHfgo8/fHIvzB4BGfanuPT38eaZLrX4fHQrvL2EAZNW8+9vd5CZpeMVFVXuPEdQHsg7Jm6q7b4biMgoEUkUkcS0tDSXFKeUuk7l1jB6JXR4HnZ8De8mQNJHDq9IblqpFN8+0ZaBzeOYuHwv97y7kvUHz7i4aFUQ7gwCe5Oj2u17ZoyZaIxJMMYkREdHO7kspZRDPv7QYSw8sgrK1odvnoCp3SFtl93mwf4+vNS7PlOGNSP9UhZ9xq3iyZkbOKLXHRQp7gyCVKBinp8rAEfcVItS6lZEVYeh86wB7E5sh/GtYcm/ISvTbvOONcvw45868Ifbq/Hd1mPc/tpSXv9+NxevZLu4cGWPO4Pga2CwrfdQSyDdGHPUjfUopW6FiDWk9eOJULc3LPsPTGgN+1bYbR7s78OfOtVk8Z/ac2ftsry9OJnbX13GnA2pOkyFm4mzrgQUkRlAByAKOA78DfAFMMZMEBEB3sXqWXQRGGaMSbzZdhMSEkxi4k2bKaVcbc+PMO//4Mx+aDQIOv3TukDNgcT9p3lx3nY2p6bTsGIEL/SoQ9NKka6r18OISJIxJsHusuJ2SbgGgVJF2JWLsPy/sOodCAiHzi9Dg37W3oMdubmGORsO89+FOzl+7jL3NIxlbNdalI8IdHHhJZ8GgVLKtY5vg2+ehNR1UKUDdH8dSld12PzC5WzeX7aH95fvBeDhdlV4uH1Vgv19XFOvB9AgUEq5Xm4uJH5ozXWQc8W6KO22J8DHz+Eqh89e4j8LdvL1piOUDfPnmc616N24PF5e9vcoVMFpECil3OfcUVjwjHXtQUQctH8WGvQHb8ef9pMOnObFeTvYdOgsDSqE80KPOiRUdny+Qd2cBoFSyv1SFlt7B0c3Qunq0PF5a3Y0L/udF3NzDV9tOsx/Fuzi2LlMejSI4elONakcFeziwksGDQKlVNFgDOycBz++BGk7oFx96PgXqNHZ4Qnli1eyeX/ZXt5fvocr2bl0bxDLI+2rUic2zMXFF28aBEqpoiU3B7Z+AUtesrqbVmgOd/wV4ts5XOXE+Uwmr9zPJ6sPkHE5m9trleHRDlX1kFEBaRAopYqmnCxrVNPl/4NzhyG+Pdz+V2voawfSL2bx8c/7mbJqP6cvXKF5fCke7VCV9jWiEQd7FUqDQClV1GVlQuJkWPEaXDwJNbrC7X+2Dh05cPFKNjPXHmLSir0cTc+kbmwYj3aoRpd65fDWXkY30CBQShUPlzNgzQRY9TZkpkPdPtZJ5ajqDle5kp3L3A2HGb9sD/tOXqBKVDCj21elV+Py+PnoJIxXaRAopYqXS2dg1buwejxkX4KGA61RTyPiHK6Sk2v4busx3luSwvaj54gJD+ChtlUY0LwiQX56YZoGgVKqeMpIg5VvwLoPwORa1x+0HJ3vISNjDMt2pzFuyR7W7j9NqWA/ht1WmcGtKhMe5OvC4osWDQKlVPGWfhhWvg4bp0PWRajUxgqEmt2sGdQcWLf/NOOWpLBkVxoh/j7c17QC9zePo2a5UBcWXzRoECilSoZLZ2D9NFg7CdIPWoeKmo+Cxg9CYITD1bYdSWfi8r0s2HKMKzm5NImL4P7mcfRoEEugn+MgKUk0CJRSJUtONuz61jqxfOAn8A2GRvdDi9H5nlg+feEKX65PZfrag+xNu0BogA+9G5dnQLO4En+BmgaBUqrkOroJVk+ArbOtwe2q3QktHoGqtzscvsIYw9p9p5mx9iDfbj3GlexcGlaMYGDzivRoEFsiRz3VIFBKlXwZJyBxijXiacZxazyjFg9Dw/vBP8ThamcvXuHL9YeZvvYgKScyCPH3oWejWO5vHke98uEu/AWcy21BICJdgLcAb+ADY8wr1y2PAz4CImxtnjXGfJvfNjUIlFL5yr4C2+bAmvFwZAP4h0OTB61zCZGVHK5mjCHpwBmmrz3I/M1HuZydS4MK4QxoFsc9jWIJKeZ7CW4JAhHxBnYDd2FNVL8OuN8Ysz1Pm4nABmPMeBGpA3xrjKmc33Y1CJRSBWIMHFprBcL2rwEDVe+AhgOs3kZ+QQ5XTb+YxZwNqcxYe4hdx88T5OdNz0ax9GlSgaZxkcVyfoT8gsCZEdccSDHG7LUVMRPoCWzP08YAV8/QhANHnFiPUsqTiEBcC+uWnmoNYbFpFnwxAvxCoU5PKxQqtb7hXEJ4kC9DW8cz5LbKbDh0lhlrDjJnw2FmrD1EubAAutWPoXuDGBpXjCiWoXA9Z+4R3Ad0McaMtP38INDCGPN4njYxwCIgEggG7jTGJNnZ1ihgFEBcXFzTAwcOOKVmpVQJl5sLB1bCppmw/Su4kgHhFa15lRsMgOgaDlfNuJzN4h3Hmbf5KMt2pXElJ5fY8AC6N4ihe4NYGlYIL9KD3rnr0FBfoPN1QdDcGPOHPG3+aKvhNRFpBXwI1DPG5Drarh4aUkoViisXYed82DwT9vxoXbkc28TaS6h3LwRHOVz1XGYWP2w/zvzNR1menEZWjqF8RCA9GsTQo0Es9cqHFblQcFcQtAL+bozpbPv5OQBjzMt52mzD2ms4ZPt5L9DSGHPC0XY1CJRShe78Mdgy2wqFY1vAyweq3WWFQo0u4BvgcNX0S1ks2naM+VuOsjL5JNm5hrhSQdaeQv0Y6sYWjVBwVxD4YJ0svgM4jHWyeKAxZlueNguAWcaYqSJSG1gMlDf5FKVBoJRyquPbrENHmz+DjGNWr6N6va1DRxVbOLw2AayuqIu2HWfelqP8lHKSnFxDfFQw3W3nFGqVC3VbKLiz+2g34E2srqGTjTEviciLQKIx5mtbT6FJQAjWieNnjDGL8tumBoFSyiVyc2DfMisUdnxjjXEUGgs1u0Kt7lC5Lfj4OVz99IUrLNx2jPmbj7Jqz0lyDVQsFUiHGmXoWCuaVlWiXDq8hV5QppRSv8flDOt8ws5vIGWxFQr+YVD9LisUqt0FAY6HqDiZcZmF246xZOcJfko5xaWsHPx8vGhVpTQdakbTsWYZKkcFO/VX0CBQSqnCknUJ9i61gmHXAmtGNS9fa77lWt2taxTCYhyunpmVw7r9p1myM42lu0+wN+0CAPFRwbSvEU3HWmVoEV+KAN/C3VvQIFBKKWfIzbEuWts1H3bMgzP7rPvLN7VCoVYPiKphXdPgwIFTF1i6K42lu06was8pLmfnEuDrxW1Vo+hYM5oONctQsZTji98KSoNAKaWczRhI22k7hDQfjqy37i9V9ddQqJCQ7/wJmVk5/Lz3FEt3nmDJrjQOnr4IQNXoYDrWLEOPhrE0quh4uO38aBAopZSrnTtiDZW9cz7sWw652RAQYR1CqtIBqnaEyHiHewvGGPadvMAS297Cmr2nGd2+Cn/sVPM3laNBoJRS7pSZDsnfw94lsGcpnEu17o+Is0KhSkeIbw/BpR1u4uKVbK5k5xIR5LinUn40CJRSqqgwBk7tsUJh71LYtwIup1vLyjX4dW8hrhX4Bhbaw2oQKKVUUZWTDUc3wh5bMBxaA7lZ4O1vDZhXpaMVDjEN8z2/cDMaBEopVVxcuQAHVlmhsHcpHN9q3R8YCW2fhtsez29th9w1DLVSSqlb5RdsXahW/S7r54wTsHeZFQqh5ZzykBoESilVlIWUgQZ9rZuTOB49SSmllEfQIFBKKQ+nQaCUUh5Og0AppTycBoFSSnk4DQKllPJwGgRKKeXhNAiUUsrDFbshJkQkDTjwG1ePAk4WYjmFpajWBUW3Nq3r1mhdt6Yk1lXJGBNtb0GxC4LfQ0QSHY214U5FtS4ourVpXbdG67o1nlaXHhpSSikPp0GglFIeztOCYKK7C3CgqNYFRbc2revWaF23xqPq8qhzBEoppW7kaXsESimlrqNBoJRSHq5EBoGIdBGRXSKSIiLP2lnuLyKzbMvXiEhlF9RUUUSWiMgOEdkmIk/aadNBRNJFZKPt9oKz67I97n4R2WJ7zBvmARXL27bna7OINHFBTTXzPA8bReSciDx1XRuXPV8iMllETojI1jz3lRKR70Uk2fY10sG6Q2xtkkVkiAvq+p+I7LT9reaISISDdfP9uzuhrr+LyOE8f69uDtbN9//XCXXNylPTfhHZ6GBdpzxfjt4bXPr6MsaUqBvgDewBqgB+wCagznVtHgUm2L4fAMxyQV0xQBPb96HAbjt1dQDmueE52w9E5bO8G7AAEKAlsMYNf9NjWBfEuOX5AtoBTYCtee77L/Cs7ftngf/YWa8UsNf2NdL2faST6+oE+Ni+/4+9ugryd3dCXX8Hni7A3zrf/9/Cruu65a8BL7jy+XL03uDK11dJ3CNoDqQYY/YaY64AM4Ge17XpCXxk+342cIeIiDOLMsYcNcast31/HtgBlHfmYxainsDHxrIaiBCRGBc+/h3AHmPMb72i/HczxiwHTl93d97X0UdALzurdga+N8acNsacAb4HujizLmPMImNMtu3H1UCFwnq831NXARXk/9cpddneA/oBMwrr8QpYk6P3Bpe9vkpiEJQHDuX5OZUb33B/aWP7h0kHSrukOsB2KKoxsMbO4lYisklEFohIXReVZIBFIpIkIqPsLC/Ic+pMA3D8z+mO5+uqssaYo2D9MwNl7LRx93M3HGtvzp6b/d2d4XHbIavJDg51uPP5agscN8YkO1ju9OfruvcGl72+SmIQ2Ptkf30f2YK0cQoRCQG+AJ4yxpy7bvF6rMMfDYF3gLmuqAlobYxpAnQFHhORdtctd+fz5QfcA3xuZ7G7nq9b4c7n7s9ANvCpgyY3+7sXtvFAVaARcBTrMMz13PZ8AfeT/96AU5+vm7w3OFzNzn23/HyVxCBIBSrm+bkCcMRRGxHxAcL5bbuxt0REfLH+0J8aY768frkx5pwxJsP2/beAr4hEObsuY8wR29cTwBys3fO8CvKcOktXYL0x5vj1C9z1fOVx/OohMtvXE3bauOW5s5007AE8YGwHk69XgL97oTLGHDfG5BhjcoFJDh7PXc+XD9AHmOWojTOfLwfvDS57fZXEIFgHVBeReNunyQHA19e1+Rq4enb9PuBHR/8shcV2/PFDYIcx5nUHbcpdPVchIs2x/j6nnFxXsIiEXv0e60Tj1uuafQ0MFktLIP3qLqsLOPyU5o7n6zp5X0dDgK/stFkIdBKRSNuhkE62+5xGRLoAY4F7jDEXHbQpyN+9sOvKe16pt4PHK8j/rzPcCew0xqTaW+jM5yuf9wbXvb4K+wx4Ubhh9XLZjdX74M+2+17E+scACMA61JACrAWquKCmNli7bJuBjbZbN2A0MNrW5nFgG1ZPidXAbS6oq4rt8TbZHvvq85W3LgHesz2fW4AEF/0dg7De2MPz3OeW5wsrjI4CWVifwkZgnVdaDCTbvpaytU0APsiz7nDbay0FGOaCulKwjhtffZ1d7SEXC3yb39/dyXVNs71+NmO9ycVcX5ft5xv+f51Zl+3+qVdfV3nauuT5yue9wWWvLx1iQimlPFxJPDSklFLqFmgQKKWUh9MgUEopD6dBoJRSHk6DQCmlPJwGgVIuJNaIqfPcXYdSeWkQKKWUh9MgUMoOERkkImttY8+/LyLeIpIhIq+JyHoRWSwi0ba2jURktfw6/n+k7f5qIvKDbVC89SJS1bb5EBGZLdacAZ86e+RbpW5Gg0Cp64hIbaA/1iBjjYAc4AEgGGvcoybAMuBvtlU+BsYaYxpgXTl79f5PgfeMNSjebVhXtII1uuRTWGPOVwFaO/2XUiofPu4uQKki6A6gKbDO9mE9EGvAr1x+HZTsE+BLEQkHIowxy2z3fwR8bhuXprwxZg6AMSYTwLa9tcY2po1Ys2FVBlY6/9dSyj4NAqVuJMBHxpjnrrlT5K/XtctvfJb8DvdczvN9Dvp/qNxMDw0pdaPFwH0iUgZ+mTu2Etb/y322NgOBlcaYdOCMiLS13f8gsMxY48mnikgv2zb8RSTIpb+FUgWkn0SUuo4xZruI/AVrNiovrJEqHwMuAHVFJAlrVrv+tlWGABNsb/R7gWG2+x8E3heRF23b6OvCX0OpAtPRR5UqIBHJMMaEuLsOpQqbHhpSSikPp3sESinl4XSPQCmlPJwGgVJKeTgNAqWU8nAaBEop5eE0CJRSysP9P3H+41XoUI+nAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "from IPython import display\n",
    "import torch.utils.data as Data\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "#下载MNIST手写数据集 :包括训练集和测试集\n",
    "train_dataset = torchvision.datasets.MNIST(root='./Datasets/MNIST', train=True,  download=True, transform=transforms.ToTensor())  \n",
    "test_dataset = torchvision.datasets.MNIST(root='./Datasets/MNIST', train=False,  download=True, transform=transforms.ToTensor())  \n",
    "\n",
    "batch_size = 32  \n",
    "train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=0)  \n",
    "test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False,  num_workers=0) \n",
    "\n",
    "#超参数初始化\n",
    "num_inputs=784 #28*28\n",
    "num_hiddens=256\n",
    "num_outputs=10\n",
    "#参数初始化\n",
    "W1 = torch.tensor(np.random.normal(0,0.01,(num_inputs,num_hiddens)),dtype=torch.float32)\n",
    "b1 = torch.zeros(1,dtype=torch.float32)\n",
    "W2 = torch.tensor(np.random.normal(0,0.01,(num_hiddens,num_outputs)),dtype=torch.float32)\n",
    "b2 = torch.zeros(1,dtype=torch.float32)\n",
    "params = [W1,b2,W2,b2]\n",
    "for param in params:\n",
    "    param.requires_grad_(requires_grad = True)\n",
    "\n",
    "#x为矩阵tensor\n",
    "##对于一个小批量样本的训练而言，X=32*784，W=784*256，得到的relu函数参数：32*256\n",
    "##表示32个样本，每个样本在256个隐藏层神经元上的输出值\n",
    "def relu(x):\n",
    "    x = torch.max(x,torch.tensor(0.0))\n",
    "    return x\n",
    "def net (X):\n",
    "    #因为我们忽略了空间结构，所以我们使用reshape将每个二维图像转换为一个长度为num_inputs的向量。\n",
    "    X = X.reshape((-1,num_inputs))#X.shape为torch.Size([32, 1, 28, 28]) 展平为32*728\n",
    "    H = relu(torch.mm(X,W1)+b1)\n",
    "    return (torch.mm(H,W2)+b2)#所得结果shape为32*10,代表32个样本,分别在10个输出层神经元上的输出。\n",
    "\n",
    "loss = torch.nn.CrossEntropyLoss()\n",
    "def SGD(paras,lr,batch_size):  \n",
    "    for param in params:  \n",
    "        param.data -= lr * param.grad/batch_size\n",
    "#返回准确率以及loss\n",
    "flag=0\n",
    "def evaluate_accuracy_loss(net, data_iter):\n",
    "    acc_sum=0.0\n",
    "    loss_sum=0.0\n",
    "    n=0\n",
    "    global flag\n",
    "    for X,y in data_iter:\n",
    "        y_hat = net(X)\n",
    "        #if flag==0:print (y_hat)#测试一下y_hat是否已经softmax激活\n",
    "        #flag = 1\n",
    "        acc_sum += (y_hat.argmax(dim=1)==y).sum().item()\n",
    "        l = loss(y_hat,y)\n",
    "        loss_sum += l.item()*y.shape[0]#由于loss(y_hat,y)默认为求平均，因此*y.shape[0]意味着求和。\n",
    "        n+=y.shape[0]\n",
    "    return acc_sum/n,loss_sum/n\n",
    "#记录列表（list），存储训练集和测试集上经过每一轮次，loss的变化\n",
    "def train (net,train_iter,test_iter,loss,num_epochs,batch_size,params = None,lr=None,optimizer=None):\n",
    "    train_loss=[]\n",
    "    test_loss=[]\n",
    "    for epoch in range(num_epochs):#外循环控制循环轮次---跑完一轮，也就把数据走了一遍\n",
    "        train_l_sum=0.0#记录训练集上的损失\n",
    "        train_acc_num=0.0#记录训练集上的准确数\n",
    "        n =0.0\n",
    "        #step1在训练集上，进行小批量梯度下降更新参数\n",
    "        for X,y in train_iter:#内循环控制训练批次\n",
    "            y_hat = net(X)\n",
    "            #保证y与y_hat维度一致，否则将会发生广播\n",
    "            l = loss(y_hat,y)#这里计算出的loss是已经求和过的，l.size = torch.Size([]),即说明loss为表示*标量*的tensor`\n",
    "            #梯度清零\n",
    "            if optimizer is not None:\n",
    "                optimizer.zero_grad()\n",
    "            elif params is not None and params[0].grad is not None:\n",
    "                for param in params:\n",
    "                    param.grad.data.zero_()\n",
    "            l.backward()\n",
    "            if optimizer is None:\n",
    "                SGD(params,lr,batch_size)\n",
    "            else:\n",
    "                optimizer.step()\n",
    "            #每一个迭代周期中得到的训练集上的loss累积进来\n",
    "            train_l_sum += l.item()*y.shape[0]\n",
    "            #计算训练样本的准确率---将每个迭代周期中预测正确的样本数累积进来\n",
    "            train_acc_num += (y_hat.argmax(dim=1)==y).sum().item()#转为int类型\n",
    "            n += y.shape[0]\n",
    "        #step2 每经过一个轮次的训练， 记录训练集和测试集上的loss\n",
    "        #注意要取平均值，loss默认求了sum\n",
    "        train_loss.append(train_l_sum/n)#训练集loss\n",
    "        test_acc,test_l = evaluate_accuracy_loss(net,test_iter)\n",
    "        test_loss.append(test_l)\n",
    "        print(\"epoch %d,train_loss %.6f,test_loss %.6f,train_acc %.6f,test_acc %.6f\"%(epoch+1,train_loss[epoch],test_loss[epoch],train_acc_num/n,test_acc)) \n",
    "    return train_loss, test_loss\n",
    "lr = 0.01\n",
    "num_epochs=20\n",
    "train_loss,test_loss=train(net,train_iter,test_iter,loss,num_epochs,batch_size,params,lr)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "x=np.linspace(0,len(train_loss),len(train_loss))\n",
    "plt.plot(x,train_loss,label=\"train_loss\",linewidth=1.5)\n",
    "plt.plot(x,test_loss,label=\"test_loss\",linewidth=1.5)\n",
    "plt.xlabel(\"epoch\")\n",
    "plt.ylabel(\"loss\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}